diff --git a/api/types.go b/api/types.go index 8f12b5f9..9e5991dc 100644 --- a/api/types.go +++ b/api/types.go @@ -153,6 +153,7 @@ type Options struct { NumCtx int `json:"num_ctx,omitempty"` NumKeep int `json:"num_keep,omitempty"` NumBatch int `json:"num_batch,omitempty"` + NumGQA int `json:"num_gqa,omitempty"` NumGPU int `json:"num_gpu,omitempty"` MainGPU int `json:"main_gpu,omitempty"` LowVRAM bool `json:"low_vram,omitempty"` @@ -190,6 +191,7 @@ func DefaultOptions() Options { NumCtx: 2048, NumBatch: 1024, NumGPU: 1, + NumGQA: 1, LowVRAM: false, F16KV: true, UseMMap: true, diff --git a/llama/ggml-cuda.cu b/llama/ggml-cuda.cu index 0b29c815..e41ed50f 100644 --- a/llama/ggml-cuda.cu +++ b/llama/ggml-cuda.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - git e782c9e735f93ab4767ffc37462c523b73a17ddc + * llama.cpp - git 7c529cede6e84054e77a3eceab31c53de7b2f55b * * MIT License * @@ -246,7 +246,7 @@ typedef struct { static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_K block size/padding"); #define WARP_SIZE 32 -#define MATRIX_ROW_PADDING 256 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses +#define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses #define CUDA_ADD_BLOCK_SIZE 256 #define CUDA_MUL_BLOCK_SIZE 256 @@ -358,12 +358,10 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols) { } } -static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols) { +static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) { const int row = blockIdx.x*blockDim.y + threadIdx.y; const int tid = threadIdx.x; - const float eps = 1e-6f; - float tmp = 0.0f; // partial sum for thread in warp for (int col = tid; col < ncols; col += WARP_SIZE) { @@ -961,12 +959,18 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx, uint16_t aux[4]; const uint8_t * sc = (const uint8_t *)aux; +#if K_QUANTS_PER_ITERATION == 2 + uint32_t q32[4]; + const uint8_t * q4 = (const uint8_t *)q32; +#else + uint16_t q16[4]; + const uint8_t * q4 = (const uint8_t *)q16; +#endif + float tmp = 0; // partial sum for thread in warp for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { - const uint8_t * q1 = x[i].qs + q_offset; - const uint8_t * q2 = q1 + 64; const float * y1 = yy + i*QK_K + y_offset; const float * y2 = y1 + 128; @@ -979,14 +983,41 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx, aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2); aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2); +#if K_QUANTS_PER_ITERATION == 2 + const uint32_t * q1 = (const uint32_t *)(x[i].qs + q_offset); + const uint32_t * q2 = q1 + 16; + + q32[0] = q1[0] & 0x0f0f0f0f; + q32[1] = q1[0] & 0xf0f0f0f0; + q32[2] = q2[0] & 0x0f0f0f0f; + q32[3] = q2[0] & 0xf0f0f0f0; + float4 s = {0.f, 0.f, 0.f, 0.f}; float smin = 0; - for (int l = 0; l < n; ++l) { - s.x += y1[l] * (q1[l] & 0xF); s.y += y1[l+32] * (q1[l] >> 4); - s.z += y2[l] * (q2[l] & 0xF); s.w += y2[l+32] * (q2[l] >> 4); + for (int l = 0; l < 4; ++l) { + s.x += y1[l] * q4[l+0]; s.y += y1[l+32] * q4[l+ 4]; + s.z += y2[l] * q4[l+8]; s.w += y2[l+32] * q4[l+12]; smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7]; } - tmp += dall * (s.x * sc[0] + s.y * sc[1] + s.z * sc[4] + s.w * sc[5]) - dmin * smin; + tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin; +#else + const uint16_t * q1 = (const uint16_t *)(x[i].qs + q_offset); + const uint16_t * q2 = q1 + 32; + + q16[0] = q1[0] & 0x0f0f; + q16[1] = q1[0] & 0xf0f0; + q16[2] = q2[0] & 0x0f0f; + q16[3] = q2[0] & 0xf0f0; + + float4 s = {0.f, 0.f, 0.f, 0.f}; + float smin = 0; + for (int l = 0; l < 2; ++l) { + s.x += y1[l] * q4[l+0]; s.y += y1[l+32] * q4[l+2]; + s.z += y2[l] * q4[l+4]; s.w += y2[l+32] * q4[l+6]; + smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7]; + } + tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin; +#endif } #else @@ -1066,10 +1097,12 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * __restrict__ vx, uint16_t aux[4]; const uint8_t * sc = (const uint8_t *)aux; + uint16_t q16[8]; + const uint8_t * q4 = (const uint8_t *)q16; + for (int i = ix; i < num_blocks_per_row; i += 2) { const uint8_t * ql1 = x[i].qs + q_offset; - const uint8_t * ql2 = ql1 + 64; const uint8_t * qh = x[i].qh + l0; const float * y1 = yy + i*QK_K + y_offset; const float * y2 = y1 + 128; @@ -1085,15 +1118,25 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * __restrict__ vx, float4 sum = {0.f, 0.f, 0.f, 0.f}; float smin = 0; + const uint16_t * q1 = (const uint16_t *)ql1; + const uint16_t * q2 = q1 + 32; + q16[0] = q1[0] & 0x0f0f; + q16[1] = q1[8] & 0x0f0f; + q16[2] = (q1[0] >> 4) & 0x0f0f; + q16[3] = (q1[8] >> 4) & 0x0f0f; + q16[4] = q2[0] & 0x0f0f; + q16[5] = q2[8] & 0x0f0f; + q16[6] = (q2[0] >> 4) & 0x0f0f; + q16[7] = (q2[8] >> 4) & 0x0f0f; for (int l = 0; l < n; ++l) { - sum.x += y1[l+ 0] * ((ql1[l+ 0] & 0xF) + (qh[l+ 0] & (hm1 << 0) ? 16 : 0)) - + y1[l+16] * ((ql1[l+16] & 0xF) + (qh[l+16] & (hm1 << 0) ? 16 : 0)); - sum.y += y1[l+32] * ((ql1[l+ 0] >> 4) + (qh[l+ 0] & (hm1 << 1) ? 16 : 0)) - + y1[l+48] * ((ql1[l+16] >> 4) + (qh[l+16] & (hm1 << 1) ? 16 : 0)); - sum.z += y2[l+ 0] * ((ql2[l+ 0] & 0xF) + (qh[l+ 0] & (hm2 << 0) ? 16 : 0)) - + y2[l+16] * ((ql2[l+16] & 0xF) + (qh[l+16] & (hm2 << 0) ? 16 : 0)); - sum.w += y2[l+32] * ((ql2[l+ 0] >> 4) + (qh[l+ 0] & (hm2 << 1) ? 16 : 0)) - + y2[l+48] * ((ql2[l+16] >> 4) + (qh[l+16] & (hm2 << 1) ? 16 : 0)); + sum.x += y1[l+ 0] * (q4[l +0] + (qh[l+ 0] & (hm1 << 0) ? 16 : 0)) + + y1[l+16] * (q4[l +2] + (qh[l+16] & (hm1 << 0) ? 16 : 0)); + sum.y += y1[l+32] * (q4[l +4] + (qh[l+ 0] & (hm1 << 1) ? 16 : 0)) + + y1[l+48] * (q4[l +6] + (qh[l+16] & (hm1 << 1) ? 16 : 0)); + sum.z += y2[l+ 0] * (q4[l +8] + (qh[l+ 0] & (hm2 << 0) ? 16 : 0)) + + y2[l+16] * (q4[l+10] + (qh[l+16] & (hm2 << 0) ? 16 : 0)); + sum.w += y2[l+32] * (q4[l+12] + (qh[l+ 0] & (hm2 << 1) ? 16 : 0)) + + y2[l+48] * (q4[l+14] + (qh[l+16] & (hm2 << 1) ? 16 : 0)); smin += (y1[l] + y1[l+16]) * sc[2] + (y1[l+32] + y1[l+48]) * sc[3] + (y2[l] + y2[l+16]) * sc[6] + (y2[l+32] + y2[l+48]) * sc[7]; } @@ -1547,33 +1590,95 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1( #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics const block_q4_K * bq4_K = (const block_q4_K *) vbq; - const int bq8_offset = QR4_K * (iqs / QI8_1); - float sumf_d = 0.0f; float sumf_m = 0.0f; +#ifndef GGML_QKK_64 + + // iqs is in 0...15. bq8_offset = 2 * (iqs/4) -> bq8_offset = 0, 2, 4, 6 + const int bq8_offset = QR4_K * (iqs / (QI8_1/2)); + const float d = bq4_K->d; const float dmin = bq4_K->dmin; - const int v = *((int *) &bq4_K->qs[sizeof(int) * iqs]); + // iqs = 0....3 -> bq8_offset = 0, want q4_offset = 0, 4, 8, 12 + // iqs = 4....7 -> bq8_offset = 2, want q4_offset = 32, 36, 40, 44 + // iqs = 8...11 -> bq8_offset = 4, want q4_offset = 64, 68, 72, 76 + // iqs = 12..15 -> bq8_offset = 6, want q4_offset = 96, 100, 104, 108 + + const int * q4 = (const int *)(bq4_K->qs + 16 * bq8_offset + 4 * (iqs%4)); + const int v1 = q4[0]; + const int v2 = q4[4]; + + const uint16_t * scales = (const uint16_t *)bq4_K->scales; + uint16_t aux[2]; + const int j = bq8_offset/2; + if (j < 2) { + aux[0] = scales[j+0] & 0x3f3f; + aux[1] = scales[j+2] & 0x3f3f; + } else { + aux[0] = ((scales[j+2] >> 0) & 0x0f0f) | ((scales[j-2] & 0xc0c0) >> 2); + aux[1] = ((scales[j+2] >> 4) & 0x0f0f) | ((scales[j-0] & 0xc0c0) >> 2); + } + const uint8_t * sc = (const uint8_t *)aux; + const uint8_t * m = sc + 2; for (int i = 0; i < QR4_K; ++i) { - const int isc = bq8_offset + i; - - uint8_t sc, m; - get_scale_min_k4(isc, bq4_K->scales, sc, m); const block_q8_1 * bq8i = bq8_1 + bq8_offset + i; - const int ui = *((int*) &bq8i->qs[sizeof(int) * (iqs % QI8_1)]); const float d8i = bq8i->d; + const int * q8 = (const int *)bq8i->qs + (iqs%4); + const int ui1 = q8[0]; + const int ui2 = q8[4]; - const int vi = (v >> (4*i)) & 0x0F0F0F0F; + const int vi1 = (v1 >> (4*i)) & 0x0F0F0F0F; + const int vi2 = (v2 >> (4*i)) & 0x0F0F0F0F; - sumf_d += d8i * (__dp4a(vi, ui, 0) * sc); // SIMD dot product - sumf_m += d8i * (__dp4a(0x01010101, ui, 0) * m); // multiply constant part of q4_K with sum of q8_1 values + const int dot1 = __dp4a(vi2, ui2, __dp4a(vi1, ui1, 0)); // SIMD dot product + const int dot2 = __dp4a(0x01010101, ui2, __dp4a(0x01010101, ui1, 0)); + + sumf_d += d8i * (dot1 * sc[i]); + sumf_m += d8i * (dot2 * m[i]); // multiply constant part of q4_K with sum of q8_1 values } return d*sumf_d - dmin*sumf_m; + +#else + + uint16_t aux16[2]; + const uint8_t * s = (const uint8_t *)aux16; + + const uint16_t * a = (const uint16_t *)bq4_K->scales; + aux16[0] = a[0] & 0x0f0f; + aux16[1] = (a[0] >> 4) & 0x0f0f; + + const float dall = bq4_K->d[0]; + const float dmin = bq4_K->d[1]; + + const float d8_1 = bq8_1[0].d; + const float d8_2 = bq8_1[1].d; + + const int ui1 = *((const int *)bq8_1[0].qs + iqs); + const int ui2 = *((const int *)bq8_1[0].qs + iqs + 4); + const int ui3 = *((const int *)bq8_1[1].qs + iqs); + const int ui4 = *((const int *)bq8_1[1].qs + iqs + 4); + + const int * q4 = (const int *)bq4_K->qs + iqs; + const int v1 = q4[0]; + const int v2 = q4[4]; + + const int dot1 = __dp4a(ui2, v2 & 0x0f0f0f0f, __dp4a(ui1, v1 & 0x0f0f0f0f, 0)); + const int dot2 = __dp4a(ui4, (v2 >> 4) & 0x0f0f0f0f, __dp4a(ui3, (v1 >> 4) & 0x0f0f0f0f, 0)); + const int dot3 = __dp4a(0x01010101, ui2, __dp4a(0x01010101, ui1, 0)); + const int dot4 = __dp4a(0x01010101, ui4, __dp4a(0x01010101, ui3, 0)); + + sumf_d += d8_1 * (dot1 * s[0]) + d8_2 * (dot2 * s[1]); + sumf_m += d8_1 * (dot3 * s[2]) + d8_2 * (dot4 * s[3]); + + return dall * sumf_d - dmin * sumf_m; + +#endif + #else return 0.0f; // only to satisfy the compiler #endif // __CUDA_ARCH__ >= MIN_CC_DP4A @@ -1585,7 +1690,11 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1( #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics const block_q5_K * bq5_K = (const block_q5_K *) vbq; - const int bq8_offset = QR5_K * (iqs / QI8_1); +#ifndef GGML_QKK_64 + + const int bq8_offset = QR5_K * (iqs / (QI8_1/2)); + const int * ql = (const int *)(bq5_K->qs + 16 * bq8_offset + 4 * (iqs%4)); + const int * qh = (const int *)(bq5_K->qh + 4 * (iqs%4)); float sumf_d = 0.0f; float sumf_m = 0.0f; @@ -1593,31 +1702,87 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1( const float d = bq5_K->d; const float dmin = bq5_K->dmin; - const int vl = *((int *) &bq5_K->qs[sizeof(int) * iqs]); + const int vl1 = ql[0]; + const int vl2 = ql[4]; - const int vh = (*((int *) &bq5_K->qh[sizeof(int) * (iqs % (QI5_K/4))])) >> bq8_offset; + const int vh1 = qh[0] >> bq8_offset; + const int vh2 = qh[4] >> bq8_offset; + + const uint16_t * scales = (const uint16_t *)bq5_K->scales; + uint16_t aux[2]; + const int j = bq8_offset/2; + if (j < 2) { + aux[0] = scales[j+0] & 0x3f3f; + aux[1] = scales[j+2] & 0x3f3f; + } else { + aux[0] = ((scales[j+2] >> 0) & 0x0f0f) | ((scales[j-2] & 0xc0c0) >> 2); + aux[1] = ((scales[j+2] >> 4) & 0x0f0f) | ((scales[j-0] & 0xc0c0) >> 2); + } + const uint8_t * sc = (const uint8_t *)aux; + const uint8_t * m = sc + 2; for (int i = 0; i < QR5_K; ++i) { - const int isc = bq8_offset + i; - - uint8_t sc, m; - get_scale_min_k4(isc, bq5_K->scales, sc, m); const block_q8_1 * bq8i = bq8_1 + bq8_offset + i; - const int ui = *((int*) &bq8i->qs[sizeof(int) * (iqs % QI8_1)]); const float d8i = bq8i->d; + const int * q8 = (const int *)bq8i->qs + (iqs%4); + const int ui1 = q8[0]; + const int ui2 = q8[4]; - const int vil = (vl >> (4*i)) & 0x0F0F0F0F; + const int vil1 = (vl1 >> (4*i)) & 0x0F0F0F0F; + const int vil2 = (vl2 >> (4*i)) & 0x0F0F0F0F; - const int vih = ((vh >> i) << 4) & 0x10101010; + const int vih1 = ((vh1 >> i) << 4) & 0x10101010; + const int vih2 = ((vh2 >> i) << 4) & 0x10101010; - const int vi = vil | vih; + const int vi1 = vil1 | vih1; + const int vi2 = vil2 | vih2; + + const int dot1 = __dp4a(vi2, ui2, __dp4a(vi1, ui1, 0)); // SIMD dot product + const int dot2 = __dp4a(0x01010101, ui2, __dp4a(0x01010101, ui1, 0)); + + sumf_d += d8i * (dot1 * sc[i]); + sumf_m += d8i * (dot2 * m[i]); - sumf_d += d8i * (__dp4a(vi, ui, 0) * sc); // SIMD dot product - sumf_m += d8i * (__dp4a(0x01010101, ui, 0) * m); // multiply constant part of q5_K with sum of q8_1 values } return d*sumf_d - dmin*sumf_m; + +#else + + const int8_t * s = bq5_K->scales; + + const float d = bq5_K->d; + + const float d8_1 = bq8_1[0].d; + const float d8_2 = bq8_1[1].d; + + const int ui1 = *((const int *)bq8_1[0].qs + iqs); + const int ui2 = *((const int *)bq8_1[0].qs + iqs + 4); + const int ui3 = *((const int *)bq8_1[1].qs + iqs); + const int ui4 = *((const int *)bq8_1[1].qs + iqs + 4); + + const int * ql = (const int *)bq5_K->qs + iqs; + const int vl1 = ql[0]; + const int vl2 = ql[4]; + + const int step = 4 * iqs; // 0, 4, 8, 12 + const int im = step/8; // = 0 for iqs = 0, 1, = 1 for iqs = 2, 3 + const int in = step%8; // 0, 4, 0, 4 + const int vh = (*((const int *)(bq5_K->qh + in))) >> im; + + const int v1 = (((vh << 4) & 0x10101010) ^ 0x10101010) | ((vl1 >> 0) & 0x0f0f0f0f); + const int v2 = (((vh << 2) & 0x10101010) ^ 0x10101010) | ((vl2 >> 0) & 0x0f0f0f0f); + const int v3 = (((vh >> 0) & 0x10101010) ^ 0x10101010) | ((vl1 >> 4) & 0x0f0f0f0f); + const int v4 = (((vh >> 2) & 0x10101010) ^ 0x10101010) | ((vl2 >> 4) & 0x0f0f0f0f); + + const float sumf_d = d8_1 * (__dp4a(ui1, v1, 0) * s[0] + __dp4a(ui2, v2, 0) * s[1]) + + d8_2 * (__dp4a(ui3, v3, 0) * s[2] + __dp4a(ui4, v4, 0) * s[3]); + + return d * sumf_d; + +#endif + #else return 0.0f; // only to satisfy the compiler #endif // __CUDA_ARCH__ >= MIN_CC_DP4A @@ -1771,11 +1936,15 @@ static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, cons } } -static __global__ void mul_mat_p021_f16_f32(const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x, const int nchannels_x) { +static __global__ void mul_mat_p021_f16_f32( + const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, + const int ncols_x, const int nrows_x, const int nchannels_x, const int nchannels_y) { + const half * x = (const half *) vx; const int row_x = blockDim.y*blockIdx.y + threadIdx.y; const int channel = blockDim.z*blockIdx.z + threadIdx.z; + const int channel_x = channel / (nchannels_y / nchannels_x); const int nrows_y = ncols_x; const int nrows_dst = nrows_x; @@ -1791,7 +1960,7 @@ static __global__ void mul_mat_p021_f16_f32(const void * __restrict__ vx, const } // x is transposed and permuted - const int ix = row_x*nchannels_x*ncols_x + channel*ncols_x + col_x; + const int ix = row_x*nchannels_x*ncols_x + channel_x*ncols_x + col_x; const float xi = __half2float(x[ix]); const int row_y = col_x; @@ -1819,12 +1988,13 @@ static __global__ void mul_mat_p021_f16_f32(const void * __restrict__ vx, const static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x, - const int row_stride_x, const int channel_stride_x) { + const int row_stride_x, const int channel_stride_x, const int channel_x_divisor) { const half * x = (const half *) vx; const int row_x = blockDim.y*blockIdx.y + threadIdx.y; const int channel = blockDim.z*blockIdx.z + threadIdx.z; + const int channel_x = channel / channel_x_divisor; const int nrows_y = ncols_x; const int nrows_dst = nrows_x; @@ -1841,7 +2011,7 @@ static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous break; } - const int ix = channel*channel_stride_x + row_x*row_stride_x + col_x; + const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x; const float xi = __half2float(x[ix]); const int row_y = col_x; @@ -2053,10 +2223,10 @@ static void norm_f32_cuda(const float * x, float * dst, const int ncols, const i norm_f32<<>>(x, dst, ncols); } -static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) { +static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) { GGML_ASSERT(ncols % WARP_SIZE == 0); const dim3 block_dims(WARP_SIZE, 1, 1); - rms_norm_f32<<>>(x, dst, ncols); + rms_norm_f32<<>>(x, dst, ncols, eps); } static void quantize_row_q8_1_cuda(const float * x, void * vy, const int ndata, const int k, cudaStream_t stream) { @@ -2285,7 +2455,10 @@ static void mul_mat_vec_q4_K_q8_1_cuda(const void * vx, const void * vy, float * const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(1, block_num_y, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - mul_mat_vec_q + // Note: we use QI4_K/2 instead of QI4_K to make the dot product template require 4 groups of quants to be processed per + // kernel call instead of 2. This results in a better perfmance because the cost of computing the k-quant scales + // is better amortized. + mul_mat_vec_q <<>>(vx, vy, dst, ncols, nrows); } @@ -2294,7 +2467,10 @@ static void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, float * const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(1, block_num_y, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - mul_mat_vec_q + // Note: we use QI5_K/2 instead of QI5_K to make the dot product template require 4 groups of quants to be processed per + // kernel call instead of 2. This results in a better perfmance because the cost of computing the k-quant scales + // is better amortized. + mul_mat_vec_q <<>>(vx, vy, dst, ncols, nrows); } @@ -2350,20 +2526,23 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { } } -static void ggml_mul_mat_p021_f16_f32_cuda(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nchannels_x, cudaStream_t stream) { - const dim3 block_nums(1, nrows_x, nchannels_x); +static void ggml_mul_mat_p021_f16_f32_cuda( + const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, + const int nchannels_x, const int nchannels_y, cudaStream_t stream) { + + const dim3 block_nums(1, nrows_x, nchannels_y); const dim3 block_dims(WARP_SIZE, 1, 1); - mul_mat_p021_f16_f32<<>>(vx, y, dst, ncols_x, nrows_x, nchannels_x); + mul_mat_p021_f16_f32<<>>(vx, y, dst, ncols_x, nrows_x, nchannels_x, nchannels_y); } static void ggml_mul_mat_vec_nc_f16_f32_cuda( const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int row_stride_x, - const int nchannels_x, const int channel_stride_x, cudaStream_t stream) { + const int nchannels_x, const int nchannels_y, const int channel_stride_x, cudaStream_t stream) { - const dim3 block_nums(1, nrows_x, nchannels_x); + const dim3 block_nums(1, nrows_x, nchannels_y); const dim3 block_dims(WARP_SIZE, 1, 1); mul_mat_vec_nc_f16_f32<<>> - (vx, y, dst, ncols_x, nrows_x, row_stride_x, channel_stride_x); + (vx, y, dst, ncols_x, nrows_x, row_stride_x, channel_stride_x, nchannels_y/nchannels_x); } static void ggml_cpy_f32_f32_cuda( @@ -2449,20 +2628,53 @@ static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) { scoped_spin_lock lock(g_cuda_pool_lock); int id; CUDA_CHECK(cudaGetDevice(&id)); - +#ifdef DEBUG_CUDA_MALLOC + int nnz = 0; + size_t max_size = 0, tot_size = 0; +#endif + size_t best_diff = 1ull << 36; + int ibest = -1; for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) { cuda_buffer& b = g_cuda_buffer_pool[id][i]; - if (b.size >= size && b.ptr != nullptr) { - void * ptr = b.ptr; - *actual_size = b.size; - b.ptr = nullptr; - b.size = 0; - return ptr; + if (b.ptr != nullptr) { +#ifdef DEBUG_CUDA_MALLOC + ++nnz; + tot_size += b.size; + if (b.size > max_size) max_size = b.size; +#endif + if (b.size >= size) { + size_t diff = b.size - size; + if (diff < best_diff) { + best_diff = diff; + ibest = i; + if (!best_diff) { + void * ptr = b.ptr; + *actual_size = b.size; + b.ptr = nullptr; + b.size = 0; + return ptr; + } + } + } } } + if (ibest >= 0) { + cuda_buffer& b = g_cuda_buffer_pool[id][ibest]; + void * ptr = b.ptr; + *actual_size = b.size; + b.ptr = nullptr; + b.size = 0; + return ptr; + } +#ifdef DEBUG_CUDA_MALLOC + fprintf(stderr, "%s: %d buffers, max_size = %u MB, tot_size = %u MB, requested %u MB\n", __func__, nnz, + (uint32_t)(max_size/1024/1024), (uint32_t)(tot_size/1024/1024), (uint32_t)(size/1024/1024)); +#endif void * ptr; - CUDA_CHECK(cudaMalloc((void **) &ptr, size)); - *actual_size = size; + size_t look_ahead_size = (size_t) (1.05 * size); + look_ahead_size = 256 * ((look_ahead_size + 255)/256); + CUDA_CHECK(cudaMalloc((void **) &ptr, look_ahead_size)); + *actual_size = look_ahead_size; return ptr; } @@ -2490,7 +2702,9 @@ static size_t g_scratch_offset = 0; static int g_device_count = -1; static int g_main_device = 0; +#ifndef GGML_CUDA_FORCE_DMMV static int g_compute_capabilities[GGML_CUDA_MAX_DEVICES]; +#endif static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0}; static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr}; @@ -2513,7 +2727,9 @@ void ggml_init_cublas() { g_tensor_split[id] = total_vram; total_vram += prop.totalGlobalMem; +#ifndef GGML_CUDA_FORCE_DMMV g_compute_capabilities[id] = 100*prop.major + 10*prop.minor; +#endif } for (int id = 0; id < g_device_count; ++id) { g_tensor_split[id] /= total_vram; @@ -2538,6 +2754,9 @@ void ggml_init_cublas() { } void ggml_cuda_set_tensor_split(const float * tensor_split) { + if (tensor_split == nullptr) { + return; + } bool all_zero = true; for (int i = 0; i < g_device_count; ++i) { if (tensor_split[i] != 0.0f) { @@ -2678,6 +2897,7 @@ inline void ggml_cuda_op_mul( (void) dst; (void) src0_ddq_i; (void) i02; + (void) i1; } inline void ggml_cuda_op_gelu( @@ -2757,8 +2977,11 @@ inline void ggml_cuda_op_rms_norm( const int64_t ne00 = src0->ne[0]; const int64_t i01_diff = i01_high - i01_low; + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + // compute - rms_norm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, cudaStream_main); + rms_norm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, eps, cudaStream_main); (void) src1; (void) dst; @@ -2805,8 +3028,8 @@ inline void ggml_cuda_op_mul_mat_vec( #endif if (use_mul_mat_vec_q) { - int64_t padded_row_size = ne00 + MATRIX_ROW_PADDING - 1; - padded_row_size -= padded_row_size % MATRIX_ROW_PADDING; + const int64_t padded_row_size = ne00 % MATRIX_ROW_PADDING == 0 ? + ne00 : ne00 - ne00 % MATRIX_ROW_PADDING + MATRIX_ROW_PADDING; size_t as; void * src1_q8_1 = ggml_cuda_pool_malloc(padded_row_size*sizeof(block_q8_1)/QK8_1, &as); quantize_row_q8_1_cuda(src1_ddf_i, src1_q8_1, ne00, padded_row_size, cudaStream_main); @@ -2973,13 +3196,18 @@ inline void ggml_cuda_op_rope( const int64_t ne00 = src0->ne[0]; const int64_t i01_diff = i01_high - i01_low; - const int n_past = ((int32_t *) src1->data)[0]; - const int n_dims = ((int32_t *) src1->data)[1]; - const int mode = ((int32_t *) src1->data)[2]; - const int n_ctx = ((int32_t *) src1->data)[3]; + const int n_past = ((int32_t *) dst->op_params)[0]; + const int n_dims = ((int32_t *) dst->op_params)[1]; + const int mode = ((int32_t *) dst->op_params)[2]; + const int n_ctx = ((int32_t *) dst->op_params)[3]; + // RoPE alteration for extended context - const float theta_scale = powf(10000.0, -2.0f/n_dims); - const float p = ((mode & 1) == 0 ? n_past + i02 : i02); + float freq_base, freq_scale; + memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float)); + memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float)); + + const float theta_scale = powf(freq_base, -2.0f/n_dims); + const float p = (((mode & 1) == 0 ? n_past + i02 : i02)) * freq_scale; bool is_glm = mode & 4; @@ -2992,6 +3220,7 @@ inline void ggml_cuda_op_rope( rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p, theta_scale, cudaStream_main); } + (void) src1; (void) dst; (void) src0_ddq_i; (void) src1_ddf_i; @@ -3010,11 +3239,12 @@ inline void ggml_cuda_op_diag_mask_inf( const int64_t ne01 = src0->ne[1]; const int64_t i01_diff = i01_high - i01_low; - const int n_past = ((int32_t *) src1->data)[0]; + const int n_past = ((int32_t *) dst->op_params)[0]; // compute diag_mask_inf_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, ne01, n_past, cudaStream_main); + (void) src1; (void) dst; (void) src0_ddq_i; (void) src1_ddf_i; @@ -3082,6 +3312,9 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm const int64_t ne11 = use_src1 ? src1->ne[1] : 1; const int64_t ne12 = use_src1 ? src1->ne[2] : 1; const int64_t ne13 = use_src1 ? src1->ne[3] : 1; + const int64_t nrows1 = use_src1 ? ggml_nrows(src1) : 1; + + GGML_ASSERT(ne03 == ne13); const int64_t ne0 = dst->ne[0]; const int64_t ne1 = dst->ne[1]; @@ -3093,12 +3326,19 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm GGML_ASSERT(!use_src1 || src1->backend != GGML_BACKEND_GPU_SPLIT); // strides for iteration over dims 3 and 2 - const int64_t num_iters = flatten_rows ? 1 : ne02 * ne03; - const int64_t stride_mod = flatten_rows ? ne02 * ne03 : 1; + const int64_t num_iters_0 = ne02 >= ne12 ? ne02*ne03 : ne12*ne13; + const int64_t num_iters = flatten_rows ? 1 : num_iters_0; + const int64_t stride_mod = flatten_rows ? num_iters_0 : 1; const int64_t src0_stride = ne00 * ne01 * stride_mod; const int64_t src1_stride = ne10 * ne11 * stride_mod; const int64_t dst_stride = ne0 * ne1 * stride_mod; + const int64_t rows_per_iter = flatten_rows ? nrows0 : ne01; + const int64_t i03_max = flatten_rows ? 1 : ne03; + const int64_t i02_max = flatten_rows ? 1 : (ne02 >= ne12 ? ne02 : ne12); + const int64_t i02_divisor = ne02 >= ne12 ? 1 : ne12 / ne02; + GGML_ASSERT(!(flatten_rows && ne02 < ne12)); + const size_t src0_ts = ggml_type_size(src0->type); const size_t src0_bs = ggml_blck_size(src0->type); @@ -3115,6 +3355,7 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm dst->op == GGML_OP_SCALE || dst->op == GGML_OP_DIAG_MASK_INF || dst->op == GGML_OP_ROPE); const bool split = src0->backend == GGML_BACKEND_GPU_SPLIT; + GGML_ASSERT(!(split && ne02 < ne12)); const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type); @@ -3151,7 +3392,7 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm row_high = id == g_device_count - 1 ? nrows0 : nrows0*g_tensor_split[id + 1]; } else { row_low = 0; - row_high = nrows0; + row_high = nrows0*i02_divisor; } if (row_low == row_high) { continue; @@ -3199,16 +3440,12 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm dst_ddf[id] = (float *) ggml_cuda_pool_malloc(size_dst_ddf, &dst_asf[id]); } - const int64_t i03_max = flatten_rows ? 1 : ne03; - const int64_t i02_max = flatten_rows ? 1 : ne02; - const int64_t rows_per_iter = flatten_rows ? nrows0 : ne01; - for (int64_t i03 = 0; i03 < i03_max; i03++) { const int64_t i13 = i03 % ne13; for (int64_t i02 = 0; i02 < i02_max; i02++) { const int64_t i12 = i02 % ne12; - const int64_t i0 = i03*ne02 + i02; + const int64_t i0 = i03*i02_max + i02; // i0 values that contain the lower/upper rows for a split tensor when using multiple GPUs const int64_t i0_offset_low = row_low/rows_per_iter; @@ -3242,10 +3479,10 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm const int64_t i11 = i13*ne12 + i12; // for split tensors the data begins at i0 == i0_offset_low - char * src0_ddq_i = src0_ddq[id] + (i0 - i0_offset_low)*src0_stride*src0_ts/src0_bs; - float * src0_ddf_i = src0_ddf[id] + (i0 - i0_offset_low)*src0_stride; + char * src0_ddq_i = src0_ddq[id] + (i0/i02_divisor - i0_offset_low)*src0_stride*src0_ts/src0_bs; + float * src0_ddf_i = src0_ddf[id] + (i0/i02_divisor - i0_offset_low)*src0_stride; float * src1_ddf_i = src1_ddf[id] + i11*src1_stride; - float * dst_ddf_i = dst_ddf[id] + (i0 - i0_offset_low)*dst_stride; + float * dst_ddf_i = dst_ddf[id] + (i0 - i0_offset_low)*dst_stride; // for split tensors the data pointer needs to be rounded down // to the bin edge for i03, i02 bins beyond the first @@ -3284,11 +3521,11 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm } } - if (!src0_on_device || !src0_is_contiguous) { + if ((!src0_on_device || !src0_is_contiguous) && i02 % i02_divisor == 0) { if (src0_is_f32) { - CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddf_i, src0, i03, i02, i01_low, i01_high, cudaStream_main)); + CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddf_i, src0, i03, i02/i02_divisor, i01_low, i01_high, cudaStream_main)); } else { - CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddq_i, src0, i03, i02, i01_low, i01_high, cudaStream_main)); + CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddq_i, src0, i03, i02/i02_divisor, i01_low, i01_high, cudaStream_main)); } } @@ -3442,6 +3679,8 @@ void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * sr const int64_t ne01 = src0->ne[1]; const int64_t ne02 = src0->ne[2]; + const int64_t ne12 = src1->ne[2]; + CUDA_CHECK(cudaSetDevice(g_main_device)); cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device]; @@ -3454,7 +3693,7 @@ void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * sr struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra; float * dst_ddf = (float *) dst_extra->data_device[g_main_device]; - ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, cudaStream_main); + ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, ne12, cudaStream_main); } void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){ @@ -3468,6 +3707,8 @@ void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1 const int64_t ne01 = src0->ne[1]; const int64_t ne02 = src0->ne[2]; + const int64_t ne12 = src1->ne[2]; + const int64_t nb01 = src0->nb[1]; const int64_t nb02 = src0->nb[2]; @@ -3486,7 +3727,7 @@ void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1 const int row_stride_x = nb01 / sizeof(half); const int channel_stride_x = nb02 / sizeof(half); - ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, channel_stride_x, cudaStream_main); + ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, cudaStream_main); } void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { @@ -3627,7 +3868,7 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) { size_t size = ggml_nbytes_split(tensor, nrows_split); const size_t original_size = size; - // pad last row to a multiple of 256 elements to avoid out-of-bounds memory accesses + // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses if (ne0 % MATRIX_ROW_PADDING != 0) { size += (MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING) * ggml_type_size(tensor->type)/ggml_blck_size(tensor->type); @@ -3643,7 +3884,7 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) { } - CUDA_CHECK(cudaMemcpy(buf, buf_host, size, cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy(buf, buf_host, original_size, cudaMemcpyHostToDevice)); extra->data_device[id] = buf; @@ -3723,7 +3964,7 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bo char * src0_ddc = (char *) src0_extra->data_device[g_main_device]; size_t offset = 0; if (tensor->op == GGML_OP_VIEW) { - memcpy(&offset, tensor->src[2]->data, sizeof(size_t)); + memcpy(&offset, tensor->op_params, sizeof(size_t)); } extra = ggml_cuda_alloc_temp_tensor_extra(); extra->data_device[g_main_device] = src0_ddc + offset; @@ -3825,18 +4066,23 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_ } func = ggml_cuda_mul; break; - case GGML_OP_GELU: - if (!any_on_device) { - return false; - } - func = ggml_cuda_gelu; - break; - case GGML_OP_SILU: - if (!any_on_device) { - return false; - } - func = ggml_cuda_silu; - break; + case GGML_OP_UNARY: + switch (ggml_get_unary_op(tensor)) { + case GGML_UNARY_OP_GELU: + if (!any_on_device) { + return false; + } + func = ggml_cuda_gelu; + break; + case GGML_UNARY_OP_SILU: + if (!any_on_device) { + return false; + } + func = ggml_cuda_silu; + break; + default: + return false; + } break; case GGML_OP_NORM: if (!any_on_device) { return false; diff --git a/llama/ggml-cuda.h b/llama/ggml-cuda.h index 445c5686..27aae6dc 100644 --- a/llama/ggml-cuda.h +++ b/llama/ggml-cuda.h @@ -1,5 +1,5 @@ /** - * llama.cpp - git e782c9e735f93ab4767ffc37462c523b73a17ddc + * llama.cpp - git 7c529cede6e84054e77a3eceab31c53de7b2f55b * * MIT License * diff --git a/llama/ggml-metal.h b/llama/ggml-metal.h index 3f8d67f3..57264feb 100644 --- a/llama/ggml-metal.h +++ b/llama/ggml-metal.h @@ -1,7 +1,7 @@ //go:build darwin /** - * llama.cpp - git e782c9e735f93ab4767ffc37462c523b73a17ddc + * llama.cpp - git 7c529cede6e84054e77a3eceab31c53de7b2f55b * * MIT License * @@ -89,6 +89,13 @@ void ggml_metal_set_tensor(struct ggml_metal_context * ctx, struct ggml_tensor * // get data from the device into host memory void ggml_metal_get_tensor(struct ggml_metal_context * ctx, struct ggml_tensor * t); +// try to find operations that can be run concurrently in the graph +// you should run it again if the topology of your graph changes +void ggml_metal_graph_find_concurrency(struct ggml_metal_context * ctx, struct ggml_cgraph * gf); + +// if the graph has been optimized for concurrently dispatch +bool ggml_metal_if_optimized(struct ggml_metal_context * ctx); + // same as ggml_graph_compute but uses Metal // creates gf->n_threads command buffers in parallel void ggml_metal_graph_compute(struct ggml_metal_context * ctx, struct ggml_cgraph * gf); diff --git a/llama/ggml-metal.m b/llama/ggml-metal.m index 4bd55fc5..709feea4 100644 --- a/llama/ggml-metal.m +++ b/llama/ggml-metal.m @@ -1,7 +1,7 @@ //go:build darwin /** - * llama.cpp - git e782c9e735f93ab4767ffc37462c523b73a17ddc + * llama.cpp - git 7c529cede6e84054e77a3eceab31c53de7b2f55b * * MIT License * @@ -64,12 +64,16 @@ struct ggml_metal_context { int n_buffers; struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS]; + int concur_list[GGML_MAX_NODES]; + int concur_list_len; + // custom kernels #define GGML_METAL_DECL_KERNEL(name) \ id function_##name; \ id pipeline_##name GGML_METAL_DECL_KERNEL(add); + GGML_METAL_DECL_KERNEL(add_row); // TODO: avoid this extra kernel, instead extend the "add" kernel to support broadcast GGML_METAL_DECL_KERNEL(mul); GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast GGML_METAL_DECL_KERNEL(scale); @@ -125,6 +129,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { ctx->device = MTLCreateSystemDefaultDevice(); ctx->queue = [ctx->device newCommandQueue]; ctx->n_buffers = 0; + ctx->concur_list_len = 0; // determine if we can use MPS if (MPSSupportsMTLDevice(ctx->device)) { @@ -185,6 +190,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { fprintf(stderr, "%s: loaded %-32s %16p\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name); GGML_METAL_ADD_KERNEL(add); + GGML_METAL_ADD_KERNEL(add_row); GGML_METAL_ADD_KERNEL(mul); GGML_METAL_ADD_KERNEL(mul_row); GGML_METAL_ADD_KERNEL(scale); @@ -243,6 +249,13 @@ void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) { ctx->n_cb = n_cb; } +bool ggml_metal_if_optimized(struct ggml_metal_context * ctx) { + if (ctx->concur_list_len) { + return true; + } + return false; +} + // finds the Metal buffer that contains the tensor data on the GPU device // the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the // Metal buffer based on the host memory pointer @@ -381,11 +394,98 @@ void ggml_metal_get_tensor( memcpy(t->data, (void *) ((uint8_t *) id_src.contents + offs), ggml_nbytes(t)); } +void ggml_metal_graph_find_concurrency( + struct ggml_metal_context * ctx, + struct ggml_cgraph * gf) { + int search_depth = gf->n_nodes; //we only find concurrency in this range to avoid wasting too much time + int nodes_unused[GGML_MAX_NODES]; + + for (int i = 0; i < GGML_MAX_NODES; i++) {ctx->concur_list[i] = 0;} + for (int i = 0; i < gf->n_nodes; i++) {nodes_unused[i] = 1;} + ctx->concur_list_len = 0; + + int n_left = gf->n_nodes; + int n_start = 0; // all nodes before n_start at nodes_unused array have been sorted and store back to ctx->concur_list + int level_pos = 0; // at ctx->concur_list, the last layer (level) ends at level_pos + + while (n_left > 0) { + // number of nodes at a layer (that can be issued concurrently) + int concurrency = 0; + for (int i = n_start; i < ((n_start + search_depth > gf->n_nodes) ? gf->n_nodes : n_start + search_depth); i++) { + if (nodes_unused[i]) { + // if the requirements for gf->nodes[i] are satisfied + int exe_flag=1; + // scan all srcs + for (int src_ind = 0; src_ind < GGML_MAX_SRC; src_ind++) { + struct ggml_tensor * src_cur = gf->nodes[i]->src[src_ind]; + if (src_cur) { + // if is leaf nodes it's satisfied. + if (src_cur->op == GGML_OP_NONE && src_cur->grad == NULL) {continue;} + + // otherwise this src should be the output from previous nodes. + int is_found = 0; + // scan 2*search_depth back because we inserted barrier. + for (int j = ((level_pos - 2*search_depth) < 0 ? 0 : (level_pos - 2*search_depth)); j < level_pos; j++) { + if (gf->nodes[ctx->concur_list[j]] == src_cur) {is_found = 1; break;} + } + if (is_found == 0) {exe_flag = 0; break;} + } + } + if (exe_flag) { + // check if nodes[i]'s data will be overwritten by a node before nodes[i]. + // if node[5] and node[3] write to the same memory region, then we can't issue node[5] before node[3] + int64_t data_start = (int64_t) gf->nodes[i]->data; + int64_t length = (int64_t) ggml_nbytes(gf->nodes[i]); + for (int j = n_start; j < i; j++) { + if (nodes_unused[j] && gf->nodes[j]->op != GGML_OP_RESHAPE \ + && gf->nodes[j]->op != GGML_OP_VIEW \ + && gf->nodes[j]->op != GGML_OP_TRANSPOSE \ + && gf->nodes[j]->op != GGML_OP_PERMUTE) { + if (((int64_t)gf->nodes[j]->data) >= data_start + length || \ + ((int64_t)gf->nodes[j]->data) + (int64_t) ggml_nbytes(gf->nodes[j]) <= data_start) { + continue; + } else { + exe_flag = 0; + } + } + } + } + if (exe_flag) { + ctx->concur_list[level_pos + concurrency] = i; + nodes_unused[i] = 0; + concurrency++; + ctx->concur_list_len++; + } + } + } + n_left -= concurrency; + // adding a barrier different layer + ctx->concur_list[level_pos + concurrency] = -1; + ctx->concur_list_len++; + // jump all sorted nodes at nodes_bak + while (!nodes_unused[n_start]) {n_start++;} + level_pos += concurrency + 1; + } + + if (ctx->concur_list_len > GGML_MAX_NODES) { + fprintf(stderr, "%s: too many elements for metal ctx->concur_list!\n", __func__); + } +} + void ggml_metal_graph_compute( struct ggml_metal_context * ctx, struct ggml_cgraph * gf) { metal_printf("%s: evaluating graph\n", __func__); + // if there is ctx->concur_list, dispatch concurrently + // else fallback to serial dispatch + MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor; + + const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_NODES; + + const int n_nodes = has_concur ? ctx->concur_list_len : gf->n_nodes; + edesc.dispatchType = has_concur ? MTLDispatchTypeConcurrent : MTLDispatchTypeSerial; + // create multiple command buffers and enqueue them // then, we encode the graph into the command buffers in parallel @@ -404,7 +504,7 @@ void ggml_metal_graph_compute( dispatch_queue_t queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT); for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) { - const int n_nodes_per_cb = (gf->n_nodes + n_cb - 1) / n_cb; + const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb; dispatch_async(queue, ^{ size_t offs_src0 = 0; @@ -415,10 +515,21 @@ void ggml_metal_graph_compute( id encoder = nil; - const int node_start = (cb_idx + 0) * n_nodes_per_cb; - const int node_end = (cb_idx == n_cb - 1) ? gf->n_nodes : (cb_idx + 1) * n_nodes_per_cb; + const int node_start = (cb_idx + 0) * n_nodes_per_cb; + const int node_end = (cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb; + + for (int ind = node_start; ind < node_end; ++ind) { + const int i = has_concur ? ctx->concur_list[ind] : ind; + + if (i == -1) { + if (encoder == nil) { + encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; + continue; + } + [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers]; + continue; + } - for (int i = node_start; i < node_end; ++i) { metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op)); struct ggml_tensor * src0 = gf->nodes[i]->src[0]; @@ -489,13 +600,19 @@ void ggml_metal_graph_compute( case GGML_OP_ADD: { if (encoder == nil) { - encoder = [command_buffer computeCommandEncoder]; + encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; } - [encoder setComputePipelineState:ctx->pipeline_add]; + if (ggml_nelements(src1) == ne10) { + // src1 is a row + [encoder setComputePipelineState:ctx->pipeline_add_row]; + } else { + [encoder setComputePipelineState:ctx->pipeline_add]; + } [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; const int64_t n = ggml_nelements(dst); @@ -504,7 +621,7 @@ void ggml_metal_graph_compute( case GGML_OP_MUL: { if (encoder == nil) { - encoder = [command_buffer computeCommandEncoder]; + encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; } if (ggml_nelements(src1) == ne10) { @@ -525,7 +642,7 @@ void ggml_metal_graph_compute( case GGML_OP_SCALE: { if (encoder == nil) { - encoder = [command_buffer computeCommandEncoder]; + encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; } const float scale = *(const float *) src1->data; @@ -539,52 +656,60 @@ void ggml_metal_graph_compute( [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; - case GGML_OP_SILU: - { - if (encoder == nil) { - encoder = [command_buffer computeCommandEncoder]; - } + case GGML_OP_UNARY: + switch (ggml_get_unary_op(gf->nodes[i])) { + case GGML_UNARY_OP_SILU: + { + if (encoder == nil) { + encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; + } - [encoder setComputePipelineState:ctx->pipeline_silu]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setComputePipelineState:ctx->pipeline_silu]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - const int64_t n = ggml_nelements(dst); + const int64_t n = ggml_nelements(dst); - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_UNARY_OP_RELU: + { + if (encoder == nil) { + encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; + } + + [encoder setComputePipelineState:ctx->pipeline_relu]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_UNARY_OP_GELU: + { + if (encoder == nil) { + encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; + } + + [encoder setComputePipelineState:ctx->pipeline_gelu]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + default: + { + fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op)); + GGML_ASSERT(false); + } } break; - case GGML_OP_RELU: - { - if (encoder == nil) { - encoder = [command_buffer computeCommandEncoder]; - } - - [encoder setComputePipelineState:ctx->pipeline_relu]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_OP_GELU: - { - if (encoder == nil) { - encoder = [command_buffer computeCommandEncoder]; - } - - [encoder setComputePipelineState:ctx->pipeline_gelu]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; case GGML_OP_SOFT_MAX: { if (encoder == nil) { - encoder = [command_buffer computeCommandEncoder]; + encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; } const int nth = 32; @@ -602,10 +727,10 @@ void ggml_metal_graph_compute( case GGML_OP_DIAG_MASK_INF: { if (encoder == nil) { - encoder = [command_buffer computeCommandEncoder]; + encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; } - const int n_past = ((int32_t *)(src1->data))[0]; + const int n_past = ((int32_t *)(dst->op_params))[0]; [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; @@ -665,7 +790,7 @@ void ggml_metal_graph_compute( } } else { if (encoder == nil) { - encoder = [command_buffer computeCommandEncoder]; + encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; } int nth0 = 32; @@ -704,8 +829,8 @@ void ggml_metal_graph_compute( GGML_ASSERT(ne02 == 1); GGML_ASSERT(ne12 == 1); - nth0 = 4; - nth1 = 16; + nth0 = 2; + nth1 = 32; [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_K_f32]; } break; case GGML_TYPE_Q3_K: @@ -713,8 +838,8 @@ void ggml_metal_graph_compute( GGML_ASSERT(ne02 == 1); GGML_ASSERT(ne12 == 1); - nth0 = 4; - nth1 = 16; + nth0 = 2; + nth1 = 32; [encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_K_f32]; } break; case GGML_TYPE_Q4_K: @@ -768,19 +893,21 @@ void ggml_metal_graph_compute( [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14]; if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || - src0t == GGML_TYPE_Q4_K) { + src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) { [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } + else if (src0t == GGML_TYPE_Q3_K) { +#ifdef GGML_QKK_64 + [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; +#else + [encoder dispatchThreadgroups:MTLSizeMake((ne01+3)/4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; +#endif + } else if (src0t == GGML_TYPE_Q5_K) { [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else if (src0t == GGML_TYPE_Q6_K) { [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_Q2_K || - src0t == GGML_TYPE_Q3_K) { - [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake(ne01, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else { [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; @@ -790,7 +917,7 @@ void ggml_metal_graph_compute( case GGML_OP_GET_ROWS: { if (encoder == nil) { - encoder = [command_buffer computeCommandEncoder]; + encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; } switch (src0->type) { @@ -819,10 +946,11 @@ void ggml_metal_graph_compute( case GGML_OP_RMS_NORM: { if (encoder == nil) { - encoder = [command_buffer computeCommandEncoder]; + encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; } - const float eps = 1e-6f; + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); const int nth = 512; @@ -841,7 +969,7 @@ void ggml_metal_graph_compute( case GGML_OP_NORM: { if (encoder == nil) { - encoder = [command_buffer computeCommandEncoder]; + encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; } const float eps = 1e-5f; @@ -863,14 +991,15 @@ void ggml_metal_graph_compute( case GGML_OP_ALIBI: { if (encoder == nil) { - encoder = [command_buffer computeCommandEncoder]; + encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; } GGML_ASSERT((src0t == GGML_TYPE_F32)); - const int n_past = ((int32_t *) src1->data)[0]; UNUSED(n_past); - const int n_head = ((int32_t *) src1->data)[1]; - const float max_bias = ((float *) src1->data)[2]; + const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past); + const int n_head = ((int32_t *) dst->op_params)[1]; + float max_bias; + memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float)); if (__builtin_popcount(n_head) != 1) { GGML_ASSERT(false && "only power-of-two n_head implemented"); @@ -905,18 +1034,17 @@ void ggml_metal_graph_compute( case GGML_OP_ROPE: { if (encoder == nil) { - encoder = [command_buffer computeCommandEncoder]; + encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; } - const int n_dims = ((int32_t *) src1->data)[1]; - const int mode = ((int32_t *) src1->data)[2]; - - const int n_past = ((int32_t *)(src1->data))[0]; + const int n_past = ((int32_t *) dst->op_params)[0]; + const int n_dims = ((int32_t *) dst->op_params)[1]; + const int mode = ((int32_t *) dst->op_params)[2]; float freq_base; float freq_scale; - memcpy(&freq_base, (int32_t *) src1->data + 4, sizeof(float)); - memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float)); + memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float)); + memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float)); [encoder setComputePipelineState:ctx->pipeline_rope]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; @@ -945,10 +1073,12 @@ void ggml_metal_graph_compute( [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; + case GGML_OP_DUP: case GGML_OP_CPY: + case GGML_OP_CONT: { if (encoder == nil) { - encoder = [command_buffer computeCommandEncoder]; + encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; } const int nth = 32; @@ -995,8 +1125,10 @@ void ggml_metal_graph_compute( [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; default: - fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op)); - GGML_ASSERT(false); + { + fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op)); + GGML_ASSERT(false); + } } } diff --git a/llama/ggml-metal.metal b/llama/ggml-metal.metal index 63e08058..b8397d9f 100644 --- a/llama/ggml-metal.metal +++ b/llama/ggml-metal.metal @@ -1,7 +1,7 @@ //go:build darwin /** - * llama.cpp - git e782c9e735f93ab4767ffc37462c523b73a17ddc + * llama.cpp - git 7c529cede6e84054e77a3eceab31c53de7b2f55b * * MIT License * @@ -95,6 +95,17 @@ kernel void kernel_add( dst[tpig] = src0[tpig] + src1[tpig]; } +// assumption: src1 is a row +// broadcast src1 into src0 +kernel void kernel_add_row( + device const float * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] + src1[tpig % ne00]; +} + kernel void kernel_mul( device const float * src0, device const float * src1, @@ -379,7 +390,7 @@ kernel void kernel_rms_norm( threadgroup_barrier(mem_flags::mem_threadgroup); // broadcast, simd group number is ntg / 32 - for (int i = ntg / 32 / 2; i > 0; i /= 2) { + for (uint i = ntg / 32 / 2; i > 0; i /= 2) { if (tpitg < i) { sum[tpitg] += sum[tpitg + i]; } @@ -404,87 +415,90 @@ kernel void kernel_rms_norm( } } -// function for calculate inner product between a q4_0 block and 32 floats (yl), sumy is SUM(yl[i]) -float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl) { +// function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i]) +// il indicates where the q4 quants begin (0 or QK4_0/4) +// we assume that the yl's have been multiplied with the appropriate scale factor +// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) +inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) { float d = qb_curr->d; - float4 acc = 0.f; - device uint16_t * qs = ((device uint16_t *)qb_curr + 1); - for (int i = 0; i < 16; i+=2) { - acc[0] += yl[i] * (qs[i / 2] & 0x000F); - acc[1] += yl[i + 16] * (qs[i / 2] & 0x00F0); - acc[2] += yl[i + 1] * (qs[i / 2] & 0x0F00); - acc[3] += yl[i + 17] * (qs[i / 2] & 0xF000); + float2 acc = 0.f; + device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2); + for (int i = 0; i < 8; i+=2) { + acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F) + + yl[i + 1] * (qs[i / 2] & 0x0F00); + acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0) + + yl[i + 9] * (qs[i / 2] & 0xF000); } - return d * (sumy * -8.f + acc[0] + acc[1]/16.f + acc[2]/256.f + acc[3]/4096.f); + return d * (sumy * -8.f + acc[0] + acc[1]); } -// function for calculate inner product between a q4_1 block and 32 floats (yl), sumy is SUM(yl[i]) -float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl) { +// function for calculate inner product between half a q4_1 block and 16 floats (yl), sumy is SUM(yl[i]) +// il indicates where the q4 quants begin (0 or QK4_0/4) +// we assume that the yl's have been multiplied with the appropriate scale factor +// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) +inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) { float d = qb_curr->d; float m = qb_curr->m; - float4 acc = 0.f; - device uint16_t * qs = ((device uint16_t *)qb_curr + 2); - for (int i = 0; i < 16; i+=2) { - acc[0] += yl[i] * (qs[i / 2] & 0x000F); - acc[1] += yl[i + 16] * (qs[i / 2] & 0x00F0); - acc[2] += yl[i + 1] * (qs[i / 2] & 0x0F00); - acc[3] += yl[i + 17] * (qs[i / 2] & 0xF000); + device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2); + float2 acc = 0.f; + for (int i = 0; i < 8; i+=2) { + acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F) + + yl[i + 1] * (qs[i / 2] & 0x0F00); + acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0) + + yl[i + 9] * (qs[i / 2] & 0xF000); } - return d * (acc[0] + acc[1]/16.f + acc[2]/256.f + acc[3]/4096.f) + sumy * m; + return d * (acc[0] + acc[1]) + sumy * m; } // putting them in the kernel cause a significant performance penalty #define N_DST 4 // each SIMD group works on 4 rows #define N_SIMDGROUP 2 // number of SIMD groups in a thread group #define N_SIMDWIDTH 32 // assuming SIMD group size is 32 -template +//Note: This is a template, but strictly speaking it only applies to +// quantizations where the block size is 32. It also does not +// giard against the number of rows not being divisible by +// N_DST, so this is another explicit assumption of the implementation. +template void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst, int64_t ne00, int64_t ne10, int64_t ne0, int64_t ne01, uint2 tgpig, uint tiisg, uint sgitg) { const int nb = ne00/QK4_0; const int r0 = tgpig.x; const int r1 = tgpig.y; - device const block_q_type * x = (device const block_q_type *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb; + const int first_row = (r0 * nsg + sgitg) * nr; + device const block_q_type * x = (device const block_q_type *) src0 + first_row * nb; device const float * y = (device const float *) src1 + r1*ne10; - float4 y_curr[8]; // src1 vector cache - float sumf[N_DST]={0.f}, all_sum; - thread float * yl=(thread float *)y_curr; + float yl[16]; // src1 vector cache + float sumf[nr]={0.f}; - // each thread in a SIMD group deals with 1 block. - for (int column = 0; column < nb / N_SIMDWIDTH; column++) { + const int ix = tiisg/2; + const int il = 8*(tiisg%2); + + device const float * yb = y + ix * QK4_0 + il; + + // each thread in a SIMD group deals with half a block. + for (int ib = ix; ib < nb; ib += nw/2) { float sumy = 0; - for (int i = 0; i < QK4_0 / 4; i++) { - y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0)) + i); - sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3]; + for (int i = 0; i < 8; i += 2) { + sumy += yb[i] + yb[i+1]; + yl[i+0] = yb[i+ 0]; + yl[i+1] = yb[i+ 1]/256.f; + sumy += yb[i+16] + yb[i+17]; + yl[i+8] = yb[i+16]/16.f; + yl[i+9] = yb[i+17]/4096.f; } - for (int row = 0; row < N_DST; row++) { - sumf[row] += block_q_n_dot_y(x+(tiisg + row * nb + column * N_SIMDWIDTH), sumy, yl); + for (int row = 0; row < nr; row++) { + sumf[row] += block_q_n_dot_y(x+ib+row*nb, sumy, yl, il); } + + yb += QK4_0 * 16; } - // from now loads two rows every time and 16 blocks per row - int ir = tiisg / (N_SIMDWIDTH / 2); - int ib = tiisg % (N_SIMDWIDTH / 2); - for (int ind = 0; ind < (nb % N_SIMDWIDTH + N_SIMDWIDTH / 2 - 1)/(N_SIMDWIDTH / 2); ind++) { - int nb_start = (nb / N_SIMDWIDTH) * N_SIMDWIDTH + ind * (N_SIMDWIDTH / 2); //where the left blocks start - float sumy = 0; - for (int i = 0; i < QK4_0 / 4; i++) { - y_curr[i] = *((device float4 *)(y + (nb_start + ib) * QK4_0) + i); - sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3]; - } - - for (int row = 0; row < N_DST; row+=2) { - if (nb_start + ib < nb) { - sumf[row + ir] += block_q_n_dot_y(x + (nb_start + ib + (row + ir) * nb), sumy, yl); - } - } - } - - for (int row = 0; row < N_DST; ++row) { - all_sum = simd_sum(sumf[row]); - if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) { - dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum; + for (int row = 0; row < nr; ++row) { + const float tot = simd_sum(sumf[row]); + if (tiisg == 0 && first_row + row < ne01) { + dst[r1*ne0 + first_row + row] = tot; } } } @@ -500,7 +514,7 @@ kernel void kernel_mul_mat_q4_0_f32( uint2 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg); + mul_vec_q_n_f32(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg); } kernel void kernel_mul_mat_q4_1_f32( @@ -514,7 +528,7 @@ kernel void kernel_mul_mat_q4_1_f32( uint2 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg); + mul_vec_q_n_f32(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg); } kernel void kernel_mul_mat_f16_f32( @@ -1237,111 +1251,137 @@ kernel void kernel_mul_mat_q2_K_f32( constant int64_t & ne00, constant int64_t & ne10, constant int64_t & ne0, - threadgroup float * sum [[threadgroup(0)]], + constant int64_t & ne01[[buffer(4)]], uint2 tgpig[[threadgroup_position_in_grid]], - uint2 tpitg[[thread_position_in_threadgroup]], - uint2 tptg[[threads_per_threadgroup]]) { + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; - const int64_t r0 = tgpig.x; - const int64_t r1 = tgpig.y; + const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int ib_row = first_row * nb; + device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row; + device const float * y = (device const float *) src1 + r1*ne10; + float yl[32]; + float sumf[N_DST]={0.f}, all_sum; - device const block_q2_K * x = (device const block_q2_K *) src0 + r0*nb; - device const float * yy = (device const float *) src1 + r1*ne10; - - const int nth = tptg.x*tptg.y; - const int ith = tptg.y*tpitg.x + tpitg.y; - - float sumf = 0; + const int step = sizeof(block_q2_K) * nb; #if QK_K == 256 - const int tid = tpitg.y; // 0...16 - const int il = tid/4; // 0...3 - const int ir = tid%4; // 0...3 - const int ip = il/2; // 0 or 1 - const int shift1 = 4*(il%2);// 0 or 4 - const int shift2 = shift1+2;// 2 or 6 - const int n = 8; - const int is = 4*il + (n*ir)/16; + const int ix = tiisg/8; // 0...3 + const int it = tiisg%8; // 0...7 + const int im = it/4; // 0 or 1 + const int ir = it%4; // 0...3 + const int is = (8*ir)/16;// 0 or 1 - const int y_offset = 64*il + n*ir; - const int q_offset = 32*ip + n*ir; + device const float * y4 = y + ix * QK_K + 128 * im + 8 * ir; - for (int i = tpitg.x; i < nb; i += tptg.x) { + for (int ib = ix; ib < nb; ib += 4) { - device const uint8_t * q = x[i].qs + q_offset; - device const uint8_t * scales = x[i].scales + is; - - uint8_t d1 = scales[0] & 0xF; - uint8_t d2 = scales[2] & 0xF; - uint8_t m1 = scales[0] >> 4; - uint8_t m2 = scales[2] >> 4; - - device const float * y = yy + i*QK_K + y_offset; - - float2 s = {0.f, 0.f}; - float smin = 0; - for (int l = 0; l < n; ++l) { - s[0] += y[l+ 0] * ((q[l] >> shift1) & 3); - s[1] += y[l+32] * ((q[l] >> shift2) & 3); - smin += y[l+ 0] * m1 + y[l+32] * m2; + float4 sumy = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; ++i) { + yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0]; + yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8]; + yl[i+16] = y4[i+64]; sumy[2] += yl[i+16]; + yl[i+24] = y4[i+96]; sumy[3] += yl[i+24]; } - const float dall = (float)x[i].d; - const float dmin = (float)x[i].dmin; + device const uint8_t * sc = (device const uint8_t *)x[ib].scales + 8*im + is; + device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir; + device const half * dh = &x[ib].d; - sumf += dall * (s[0] * d1 + s[1] * d2) - dmin * smin; + for (int row = 0; row < N_DST; row++) { + float4 acc1 = {0.f, 0.f, 0.f, 0.f}; + float4 acc2 = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; i += 2) { + acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003); + acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300); + acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c); + acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00); + acc1[2] += yl[i+16] * (qs[i/2] & 0x0030); + acc2[2] += yl[i+17] * (qs[i/2] & 0x3000); + acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0); + acc2[3] += yl[i+25] * (qs[i/2] & 0xc000); + } + float dall = dh[0]; + float dmin = dh[1] * 1.f/16.f; + sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f + + (acc1[1] + 1.f/256.f * acc2[1]) * (sc[2] & 0xF) * 1.f/ 4.f + + (acc1[2] + 1.f/256.f * acc2[2]) * (sc[4] & 0xF) * 1.f/16.f + + (acc1[3] + 1.f/256.f * acc2[3]) * (sc[6] & 0xF) * 1.f/64.f) - + dmin * (sumy[0] * (sc[0] & 0xF0) + sumy[1] * (sc[2] & 0xF0) + sumy[2] * (sc[4] & 0xF0) + sumy[3] * (sc[6] & 0xF0)); + + qs += step/2; + sc += step; + dh += step/2; + } + + y4 += 4 * QK_K; } #else - const int il = 4 * tpitg.x; + const int ix = tiisg/2; // 0...15 + const int it = tiisg%2; // 0...1 - uint32_t aux[2]; - thread const uint8_t * d = (thread const uint8_t *)aux; - thread const uint8_t * m = (thread const uint8_t *)aux + 4; + device const float * y4 = y + ix * QK_K + 8 * it; - for (int i = tpitg.y; i < nb; i += tptg.y) { + for (int ib = ix; ib < nb; ib += 16) { - device const uint8_t * q = x[i].qs + il; - device const float * y = yy + i*QK_K + il; - - const float dall = (float)x[i].d; - const float dmin = (float)x[i].dmin; - - device const uint32_t * a = (device const uint32_t *)x[i].scales; - aux[0] = a[0] & 0x0f0f0f0f; - aux[1] = (a[0] >> 4) & 0x0f0f0f0f; - - for (int l = 0; l < 4; ++l) { - sumf += y[l+ 0] * (dall * d[0] * ((q[l] >> 0) & 3) - dmin * m[0]) - + y[l+16] * (dall * d[1] * ((q[l] >> 2) & 3) - dmin * m[1]) - + y[l+32] * (dall * d[2] * ((q[l] >> 4) & 3) - dmin * m[2]) - + y[l+48] * (dall * d[3] * ((q[l] >> 6) & 3) - dmin * m[3]); + float4 sumy = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; ++i) { + yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0]; + yl[i+ 8] = y4[i+16]; sumy[1] += yl[i+ 8]; + yl[i+16] = y4[i+32]; sumy[2] += yl[i+16]; + yl[i+24] = y4[i+48]; sumy[3] += yl[i+24]; } + + device const uint8_t * sc = (device const uint8_t *)x[ib].scales; + device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it; + device const half * dh = &x[ib].d; + + for (int row = 0; row < N_DST; row++) { + + float4 acc1 = {0.f, 0.f, 0.f, 0.f}; + float4 acc2 = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; i += 2) { + acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003); + acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300); + acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c); + acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00); + acc1[2] += yl[i+16] * (qs[i/2] & 0x0030); + acc2[2] += yl[i+17] * (qs[i/2] & 0x3000); + acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0); + acc2[3] += yl[i+25] * (qs[i/2] & 0xc000); + } + + float dall = dh[0]; + float dmin = dh[1]; + sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f + + (acc1[1] + 1.f/256.f * acc2[1]) * (sc[1] & 0xF) * 1.f/ 4.f + + (acc1[2] + 1.f/256.f * acc2[2]) * (sc[2] & 0xF) * 1.f/16.f + + (acc1[3] + 1.f/256.f * acc2[3]) * (sc[3] & 0xF) * 1.f/64.f) - + dmin * (sumy[0] * (sc[0] >> 4) + sumy[1] * (sc[1] >> 4) + sumy[2] * (sc[2] >> 4) + sumy[3] * (sc[3] >> 4)); + + qs += step/2; + sc += step; + dh += step/2; + } + + y4 += 16 * QK_K; } #endif - sum[ith] = sumf; - - // - // Accumulate the sum from all threads in the threadgroup - // - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ith%4 == 0) { - for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ith%16 == 0) { - for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ith == 0) { - for (int i = 16; i < nth; i += 16) sum[0] += sum[i]; - dst[r1*ne0 + r0] = sum[0]; + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + first_row + row] = all_sum; + } } } +#if QK_K == 256 kernel void kernel_mul_mat_q3_K_f32( device const void * src0, device const float * src1, @@ -1350,40 +1390,41 @@ kernel void kernel_mul_mat_q3_K_f32( constant int64_t & ne10, constant int64_t & ne0, constant int64_t & ne1, - threadgroup float * sum [[threadgroup(0)]], uint2 tgpig[[threadgroup_position_in_grid]], - uint2 tpitg[[thread_position_in_threadgroup]], - uint2 tptg[[threads_per_threadgroup]]) { + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { const int nb = ne00/QK_K; const int64_t r0 = tgpig.x; const int64_t r1 = tgpig.y; - device const block_q3_K * x = (device const block_q3_K *) src0 + r0*nb; + const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2; + + device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb; device const float * yy = (device const float *) src1 + r1*ne10; - const int nth = tptg.x*tptg.y; - const int ith = tptg.y*tpitg.x + tpitg.y; - -#if QK_K == 256 - - const uint8_t m3 = 3; - const int8_t m4 = 4; + float yl[16]; const uint16_t kmask1 = 0x0303; const uint16_t kmask2 = 0x0f0f; - const int tid = tpitg.y; // expecting 16 + const int tid = tiisg/2; + const int ix = tiisg%2; const int ip = tid/8; // 0 or 1 const int il = tid/2 - 4*ip; // 0...3 const int ir = tid%2; const int n = 8; const int l0 = n*ir; - const uint8_t m = 1 << (4*ip + il); + const uint16_t m1 = 1 << (4*ip + il); + const uint16_t m2 = m1 << 8; const int shift = 2*il; + const uint16_t qm1 = 0x0003 << shift; + const uint16_t qm2 = 0x0300 << shift; + const int32_t v1 = 4 << shift; + const int32_t v2 = 1024 << shift; const uint16_t s_shift1 = 4*ip; const uint16_t s_shift2 = s_shift1 + 2*(il/2); @@ -1392,93 +1433,132 @@ kernel void kernel_mul_mat_q3_K_f32( const int q_offset = 32*ip + l0; const int y_offset = 128*ip + 32*il + l0; - //float sumf = 0; - float sumf1 = 0, sumf2 = 0; - for (int i = tpitg.x; i < nb; i += tptg.x) { + const int step = sizeof(block_q3_K) * nb / 2; - const float d_all = (float)(x[i].d); + device const float * y1 = yy + ix*QK_K + y_offset; - device const uint8_t * q = x[i].qs + q_offset; - device const uint8_t * h = x[i].hmask + l0; - device const float * y = yy + i * QK_K + y_offset; + float sumf1[2] = {0.f}, sumf2[2] = {0.f}; + for (int i = ix; i < nb; i += 2) { - device const uint16_t * a = (device const uint16_t *)x[i].scales; - const char2 scales = as_type((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4))); - - float s = 0; - for (int l = 0; l < n; ++l) { - s += y[l+ 0] * ((int8_t)((q[l+ 0] >> shift) & m3) - ((h[l+ 0] & m) ? 0 : m4)); + for (int l = 0; l < 8; ++l) { + yl[l+0] = y1[l+ 0]; + yl[l+8] = y1[l+16]; } - float d = d_all * s; - sumf1 += d * scales[0]; - sumf2 += d; - //sumf += d_all * s * (scales[0] - 32); - s = 0; - for (int l = 0; l < n; ++l) { - s += y[l+16] * ((int8_t)((q[l+16] >> shift) & m3) - ((h[l+16] & m) ? 0 : m4)); + device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset); + device const uint16_t * h = (device const uint16_t *)(x[i].hmask + l0); + device const uint16_t * a = (device const uint16_t *)(x[i].scales); + device const half * dh = &x[i].d; + + for (int row = 0; row < 2; ++row) { + + const float d_all = (float)dh[0]; + const char2 scales = as_type((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4))); + + float s1 = 0, s2 = 0; + for (int l = 0; l < n; l += 2) { + const uint16_t qs = q[l/2]; + s1 += yl[l+0] * ((int32_t)(qs & qm1) - ((h[l/2] & m1) ? 0 : v1)); + s2 += yl[l+1] * ((int32_t)(qs & qm2) - ((h[l/2] & m2) ? 0 : v2)); + } + float d = d_all * (s1 + 1.f/256.f * s2); + sumf1[row] += d * scales[0]; + sumf2[row] += d; + + s1 = s2 = 0; + for (int l = 0; l < n; l += 2) { + const uint16_t qs = q[l/2+8]; + s1 += yl[l+8] * ((int32_t)(qs & qm1) - ((h[l/2+8] & m1) ? 0 : v1)); + s2 += yl[l+9] * ((int32_t)(qs & qm2) - ((h[l/2+8] & m2) ? 0 : v2)); + } + d = d_all * (s1 + 1.f/256.f * s2); + sumf1[row] += d * scales[1]; + sumf2[row] += d; + + q += step; + h += step; + a += step; + dh += step; + } - d = d_all * s; - sumf1 += d * scales[1]; - sumf2 += d; - //sumf += d_all * s * (scales[1] - 32); + + y1 += 2 * QK_K; } - //sum[ith] = sumf; - sum[ith] = sumf1 - 32.f*sumf2; + for (int row = 0; row < 2; ++row) { + const float sumf = (sumf1[row] - 32.f*sumf2[row]) / (1 << shift); + const float tot = simd_sum(sumf); + if (tiisg == 0) { + dst[r1*ne0 + first_row + row] = tot; + } + } +} #else - const int il = 4 * tpitg.x; // 0, 4, 8, 12 +kernel void kernel_mul_mat_q3_K_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne10, + constant int64_t & ne0, + constant int64_t & ne1, + uint2 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + const int nb = ne00/QK_K; + + const int64_t r0 = tgpig.x; + const int64_t r1 = tgpig.y; + + const int row = 2 * r0 + sgitg; + + device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb; + device const float * yy = (device const float *) src1 + r1*ne10; + const int ix = tiisg/4; + const int il = 4 * (tiisg%4);// 0, 4, 8, 12 const int im = il/8; // 0, 0, 1, 1 const int in = il%8; // 0, 4, 0, 4 - float sumf = 0; + float2 sum = {0.f, 0.f}; - for (int i = tpitg.y; i < nb; i += tptg.y) { + for (int i = ix; i < nb; i += 8) { const float d_all = (float)(x[i].d); - device const uint8_t * q = x[i].qs + il; - device const uint8_t * h = x[i].hmask + in; - device const float * y = yy + i * QK_K + il; + device const uint16_t * q = (device const uint16_t *)(x[i].qs + il); + device const uint16_t * h = (device const uint16_t *)(x[i].hmask + in); + device const uint16_t * s = (device const uint16_t *)(x[i].scales); + device const float * y = yy + i * QK_K + il; - const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8); - const float d2 = d_all * ((x[i].scales[0] >> 4) - 8); - const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8); - const float d4 = d_all * ((x[i].scales[1] >> 4) - 8); + const float d1 = d_all * ((int32_t)(s[0] & 0x000F) - 8); + const float d2 = d_all * ((int32_t)(s[0] & 0x00F0) - 128) * 1.f/64.f; + const float d3 = d_all * ((int32_t)(s[0] & 0x0F00) - 2048) * 1.f/4096.f; + const float d4 = d_all * ((int32_t)(s[0] & 0xF000) - 32768) * 1.f/262144.f; - for (int l = 0; l < 4; ++l) { - const uint8_t hm = h[l] >> im; - sumf += y[l+ 0] * d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((hm & 0x01) ? 0 : 4)) - + y[l+16] * d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((hm & 0x04) ? 0 : 4)) - + y[l+32] * d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((hm & 0x10) ? 0 : 4)) - + y[l+48] * d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((hm & 0x40) ? 0 : 4)); + for (int l = 0; l < 4; l += 2) { + const uint16_t hm = h[l/2] >> im; + sum[0] += y[l+ 0] * d1 * ((int32_t)(q[l/2] & 0x0003) - ((hm & 0x0001) ? 0 : 4)) + + y[l+16] * d2 * ((int32_t)(q[l/2] & 0x000c) - ((hm & 0x0004) ? 0 : 16)) + + y[l+32] * d3 * ((int32_t)(q[l/2] & 0x0030) - ((hm & 0x0010) ? 0 : 64)) + + y[l+48] * d4 * ((int32_t)(q[l/2] & 0x00c0) - ((hm & 0x0040) ? 0 : 256)); + sum[1] += y[l+ 1] * d1 * ((int32_t)(q[l/2] & 0x0300) - ((hm & 0x0100) ? 0 : 1024)) + + y[l+17] * d2 * ((int32_t)(q[l/2] & 0x0c00) - ((hm & 0x0400) ? 0 : 4096)) + + y[l+33] * d3 * ((int32_t)(q[l/2] & 0x3000) - ((hm & 0x1000) ? 0 : 16384)) + + y[l+49] * d4 * ((int32_t)(q[l/2] & 0xc000) - ((hm & 0x4000) ? 0 : 65536)); } } + const float sumf = sum[0] + sum[1] * 1.f/256.f; - sum[ith] = sumf; - -#endif - - // - // Accumulate the sum from all threads in the threadgroup - // - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ith%4 == 0) { - for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ith%16 == 0) { - for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ith == 0) { - for (int i = 16; i < nth; i += 16) sum[0] += sum[i]; - dst[r1*ne0 + r0] = sum[0]; + const float tot = simd_sum(sumf); + if (tiisg == 0) { + dst[r1*ne0 + row] = tot; } } +#endif #if QK_K == 256 kernel void kernel_mul_mat_q4_K_f32( @@ -1776,7 +1856,6 @@ kernel void kernel_mul_mat_q5_K_f32( for (int i = ix; i < nb; i += 8) { - float4 sumy = {0.f, 0.f, 0.f, 0.f}; for (int l = 0; l < 4; ++l) { yl[l+0] = y[l+ 0]; yl[l+4] = y[l+16]; diff --git a/llama/ggml-mpi.c b/llama/ggml-mpi.c index 06d47c15..82b93bb3 100644 --- a/llama/ggml-mpi.c +++ b/llama/ggml-mpi.c @@ -1,7 +1,7 @@ //go:build mpi /** - * llama.cpp - git e782c9e735f93ab4767ffc37462c523b73a17ddc + * llama.cpp - git 7c529cede6e84054e77a3eceab31c53de7b2f55b * * MIT License * diff --git a/llama/ggml-mpi.h b/llama/ggml-mpi.h index 35d7b175..61862740 100644 --- a/llama/ggml-mpi.h +++ b/llama/ggml-mpi.h @@ -1,7 +1,7 @@ //go:build mpi /** - * llama.cpp - git e782c9e735f93ab4767ffc37462c523b73a17ddc + * llama.cpp - git 7c529cede6e84054e77a3eceab31c53de7b2f55b * * MIT License * diff --git a/llama/ggml-opencl.cpp b/llama/ggml-opencl.cpp index d14e367d..b392f0b3 100644 --- a/llama/ggml-opencl.cpp +++ b/llama/ggml-opencl.cpp @@ -1,7 +1,7 @@ //go:build opencl /** - * llama.cpp - git e782c9e735f93ab4767ffc37462c523b73a17ddc + * llama.cpp - git 7c529cede6e84054e77a3eceab31c53de7b2f55b * * MIT License * diff --git a/llama/ggml-opencl.h b/llama/ggml-opencl.h index 5becb26a..94043893 100644 --- a/llama/ggml-opencl.h +++ b/llama/ggml-opencl.h @@ -1,7 +1,7 @@ //go:build opencl /** - * llama.cpp - git e782c9e735f93ab4767ffc37462c523b73a17ddc + * llama.cpp - git 7c529cede6e84054e77a3eceab31c53de7b2f55b * * MIT License * diff --git a/llama/ggml.c b/llama/ggml.c index 1f3bf620..cc9c594f 100644 --- a/llama/ggml.c +++ b/llama/ggml.c @@ -1,5 +1,5 @@ /** - * llama.cpp - git e782c9e735f93ab4767ffc37462c523b73a17ddc + * llama.cpp - git 7c529cede6e84054e77a3eceab31c53de7b2f55b * * MIT License * @@ -3466,7 +3466,9 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float //inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] *= v; } inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { -#if defined(GGML_SIMD) +#if defined(GGML_USE_ACCELERATE) + vDSP_vsmul(y, 1, &v, y, 1, n); +#elif defined(GGML_SIMD) const int np = (n & ~(GGML_F32_STEP - 1)); GGML_F32_VEC vx = GGML_F32_VEC_SET1(v); @@ -3629,7 +3631,7 @@ inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) { #endif } -inline static void ggml_vec_sum_ggf(const int n, ggml_float * s, const float * x) { +inline static void ggml_vec_sum_f32_ggf(const int n, ggml_float * s, const float * x) { ggml_float sum = 0.0; for (int i = 0; i < n; ++i) { sum += (ggml_float)x[i]; @@ -3637,6 +3639,14 @@ inline static void ggml_vec_sum_ggf(const int n, ggml_float * s, const float * x *s = sum; } +inline static void ggml_vec_sum_f16_ggf(const int n, float * s, const ggml_fp16_t * x) { + float sum = 0.0f; + for (int i = 0; i < n; ++i) { + sum += GGML_FP16_TO_FP32(x[i]); + } + *s = sum; +} + inline static void ggml_vec_max_f32(const int n, float * s, const float * x) { #ifndef GGML_USE_ACCELERATE float max = -INFINITY; @@ -3776,16 +3786,6 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "ARGMAX", "REPEAT", "REPEAT_BACK", - "ABS", - "SGN", - "NEG", - "STEP", - "TANH", - "ELU", - "RELU", - "GELU", - "GELU_QUICK", - "SILU", "SILU_BACK", "NORM", "RMS_NORM", @@ -3824,6 +3824,8 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "WIN_PART", "WIN_UNPART", + "UNARY", + "MAP_UNARY", "MAP_BINARY", @@ -3835,7 +3837,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "CROSS_ENTROPY_LOSS_BACK", }; -static_assert(GGML_OP_COUNT == 68, "GGML_OP_COUNT != 68"); +static_assert(GGML_OP_COUNT == 59, "GGML_OP_COUNT != 59"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -3856,16 +3858,6 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "argmax(x)", "repeat(x)", "repeat_back(x)", - "abs(x)", - "sgn(x)", - "-x", - "step(x)", - "tanh(x)", - "elu(x)", - "relu(x)", - "gelu(x)", - "gelu_quick(x)", - "silu(x)", "silu_back(x)", "norm(x)", "rms_norm(x)", @@ -3904,6 +3896,8 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "win_part(x)", "win_unpart(x)", + "unary(x)", + "f(x)", "f(x,y)", @@ -3915,7 +3909,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "cross_entropy_loss_back(x,y)", }; -static_assert(GGML_OP_COUNT == 68, "GGML_OP_COUNT != 68"); +static_assert(GGML_OP_COUNT == 59, "GGML_OP_COUNT != 59"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -4103,8 +4097,8 @@ bool ggml_is_numa(void) { //////////////////////////////////////////////////////////////////////////////// void ggml_print_object(const struct ggml_object * obj) { - GGML_PRINT(" - ggml_object: offset = %zu, size = %zu, next = %p\n", - obj->offs, obj->size, (const void *) obj->next); + GGML_PRINT(" - ggml_object: type = %d, offset = %zu, size = %zu, next = %p\n", + obj->type, obj->offs, obj->size, (const void *) obj->next); } void ggml_print_objects(const struct ggml_context * ctx) { @@ -4171,6 +4165,10 @@ const char * ggml_op_name(enum ggml_op op) { return GGML_OP_NAME[op]; } +const char * ggml_op_symbol(enum ggml_op op) { + return GGML_OP_SYMBOL[op]; +} + size_t ggml_element_size(const struct ggml_tensor * tensor) { return GGML_TYPE_SIZE[tensor->type]; } @@ -4240,7 +4238,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) { } size_t ggml_tensor_overhead(void) { - return GGML_OBJECT_SIZE + GGML_TENSOR_SIZE + 16; + return GGML_OBJECT_SIZE + GGML_TENSOR_SIZE; } bool ggml_is_transposed(const struct ggml_tensor * tensor) { @@ -4257,6 +4255,15 @@ bool ggml_is_contiguous(const struct ggml_tensor * tensor) { tensor->nb[3] == tensor->nb[2]*tensor->ne[2]; } +static inline bool ggml_is_contiguous_except_dim_1(const struct ggml_tensor * tensor) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return + tensor->nb[0] == GGML_TYPE_SIZE[tensor->type] && + tensor->nb[2] == tensor->nb[1]*tensor->ne[1] && + tensor->nb[3] == tensor->nb[2]*tensor->ne[2]; +} + bool ggml_is_permuted(const struct ggml_tensor * tensor) { static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); @@ -4402,7 +4409,7 @@ struct ggml_context * ggml_init(struct ggml_init_params params) { return NULL; } - const size_t mem_size = (params.mem_size + GGML_MEM_ALIGN - 1) & ~(GGML_MEM_ALIGN - 1); + const size_t mem_size = params.mem_buffer ? params.mem_size : GGML_PAD(params.mem_size, GGML_MEM_ALIGN); *ctx = (struct ggml_context) { /*.mem_size =*/ mem_size, @@ -4469,6 +4476,10 @@ size_t ggml_set_scratch(struct ggml_context * ctx, struct ggml_scratch scratch) return result; } +bool ggml_get_no_alloc(struct ggml_context * ctx) { + return ctx->no_alloc; +} + void ggml_set_no_alloc(struct ggml_context * ctx, bool no_alloc) { ctx->no_alloc = no_alloc; } @@ -4487,12 +4498,14 @@ size_t ggml_get_max_tensor_size(const struct ggml_context * ctx) { struct ggml_object * obj = ctx->objects_begin; while (obj != NULL) { - struct ggml_tensor * tensor = (struct ggml_tensor *) ((char *) ctx->mem_buffer + obj->offs); + if (obj->type == GGML_OBJECT_TENSOR) { + struct ggml_tensor * tensor = (struct ggml_tensor *) ((char *) ctx->mem_buffer + obj->offs); - const size_t size = ggml_nbytes(tensor); + const size_t size = ggml_nbytes(tensor); - if (max_size < size) { - max_size = size; + if (max_size < size) { + max_size = size; + } } obj = obj->next; @@ -4506,7 +4519,7 @@ size_t ggml_get_max_tensor_size(const struct ggml_context * ctx) { // this is an error prone process, but it is necessary to support inplace // operators when using scratch buffers // TODO: implement a better way -void ggml_scratch_save(struct ggml_context * ctx) { +static void ggml_scratch_save(struct ggml_context * ctx) { // this is needed to allow opt tensors to store their data // TODO: again, need to find a better way ctx->no_alloc_save = ctx->no_alloc; @@ -4516,7 +4529,7 @@ void ggml_scratch_save(struct ggml_context * ctx) { ctx->scratch.data = NULL; } -void ggml_scratch_load(struct ggml_context * ctx) { +static void ggml_scratch_load(struct ggml_context * ctx) { ctx->no_alloc = ctx->no_alloc_save; ctx->scratch = ctx->scratch_save; @@ -4524,12 +4537,7 @@ void ggml_scratch_load(struct ggml_context * ctx) { //////////////////////////////////////////////////////////////////////////////// -struct ggml_tensor * ggml_new_tensor_impl( - struct ggml_context * ctx, - enum ggml_type type, - int n_dims, - const int64_t* ne, - void* data) { +static struct ggml_object * ggml_new_object(struct ggml_context * ctx, enum ggml_object_type type, size_t size) { // always insert objects at the end of the context's memory pool struct ggml_object * obj_cur = ctx->objects_end; @@ -4537,63 +4545,28 @@ struct ggml_tensor * ggml_new_tensor_impl( const size_t cur_size = obj_cur == NULL ? 0 : obj_cur->size; const size_t cur_end = cur_offs + cur_size; - size_t size_needed = 0; - - if (data == NULL && !ctx->no_alloc) { - size_needed += GGML_TYPE_SIZE[type]*(ne[0]/GGML_BLCK_SIZE[type]); - for (int i = 1; i < n_dims; i++) { - size_needed *= ne[i]; - } - // align to GGML_MEM_ALIGN - size_needed = ((size_needed + GGML_MEM_ALIGN - 1)/GGML_MEM_ALIGN)*GGML_MEM_ALIGN; - } + // align to GGML_MEM_ALIGN + size_t size_needed = GGML_PAD(size, GGML_MEM_ALIGN); char * const mem_buffer = ctx->mem_buffer; struct ggml_object * const obj_new = (struct ggml_object *)(mem_buffer + cur_end); - if (ctx->scratch.data == NULL || data != NULL) { - size_needed += GGML_TENSOR_SIZE; - - if (cur_end + size_needed + GGML_OBJECT_SIZE > ctx->mem_size) { - GGML_PRINT("%s: not enough space in the context's memory pool (needed %zu, available %zu)\n", - __func__, cur_end + size_needed + GGML_OBJECT_SIZE, ctx->mem_size); - assert(false); - return NULL; - } - - *obj_new = (struct ggml_object) { - .offs = cur_end + GGML_OBJECT_SIZE, - .size = size_needed, - .next = NULL, - }; - } else { - if (ctx->scratch.offs + size_needed > ctx->scratch.size) { - GGML_PRINT("%s: not enough space in the scratch memory pool (needed %zu, available %zu)\n", - __func__, ctx->scratch.offs + size_needed, ctx->scratch.size); - assert(false); - return NULL; - } - - if (cur_end + GGML_TENSOR_SIZE + GGML_OBJECT_SIZE > ctx->mem_size) { - GGML_PRINT("%s: not enough space in the context's memory pool (needed %zu, available %zu)\n", - __func__, cur_end + GGML_TENSOR_SIZE + GGML_OBJECT_SIZE, ctx->mem_size); - assert(false); - return NULL; - } - - data = (char * const) ctx->scratch.data + ctx->scratch.offs; - - *obj_new = (struct ggml_object) { - .offs = cur_end + GGML_OBJECT_SIZE, - .size = GGML_TENSOR_SIZE, - .next = NULL, - }; - - //printf("scratch offs = %zu, size_needed = %zu\n", ctx->scratch.offs, size_needed); - - ctx->scratch.offs += size_needed; + if (cur_end + size_needed + GGML_OBJECT_SIZE > ctx->mem_size) { + GGML_PRINT("%s: not enough space in the context's memory pool (needed %zu, available %zu)\n", + __func__, cur_end + size_needed, ctx->mem_size); + assert(false); + return NULL; } + *obj_new = (struct ggml_object) { + .offs = cur_end + GGML_OBJECT_SIZE, + .size = size_needed, + .next = NULL, + .type = type, + }; + + ggml_assert_aligned(mem_buffer + obj_new->offs); + if (obj_cur != NULL) { obj_cur->next = obj_new; } else { @@ -4605,9 +4578,46 @@ struct ggml_tensor * ggml_new_tensor_impl( //printf("%s: inserted new object at %zu, size = %zu\n", __func__, cur_end, obj_new->size); - struct ggml_tensor * const result = (struct ggml_tensor *)(mem_buffer + obj_new->offs); + return obj_new; +} - ggml_assert_aligned(result); +static struct ggml_tensor * ggml_new_tensor_impl( + struct ggml_context * ctx, + enum ggml_type type, + int n_dims, + const int64_t* ne, + void* data) { + + size_t data_size = 0; + + if (data == NULL && !ctx->no_alloc) { + data_size += GGML_TYPE_SIZE[type]*(ne[0]/GGML_BLCK_SIZE[type]); + for (int i = 1; i < n_dims; i++) { + data_size *= ne[i]; + } + } + + if (ctx->scratch.data != NULL && data == NULL) { + // allocate tensor data in the scratch buffer + if (ctx->scratch.offs + data_size > ctx->scratch.size) { + GGML_PRINT("%s: not enough space in the scratch memory pool (needed %zu, available %zu)\n", + __func__, ctx->scratch.offs + data_size, ctx->scratch.size); + assert(false); + return NULL; + } + + data = (char * const) ctx->scratch.data + ctx->scratch.offs; + + ctx->scratch.offs += data_size; + + data_size = 0; + } + + struct ggml_object * const obj_new = ggml_new_object(ctx, GGML_OBJECT_TENSOR, GGML_TENSOR_SIZE + data_size); + + // TODO: for recoverable errors, we would need to free the data allocated from the scratch buffer here + + struct ggml_tensor * const result = (struct ggml_tensor *)((char *)ctx->mem_buffer + obj_new->offs); *result = (struct ggml_tensor) { /*.type =*/ type, @@ -4616,6 +4626,7 @@ struct ggml_tensor * ggml_new_tensor_impl( /*.ne =*/ { 1, 1, 1, 1 }, /*.nb =*/ { 0, 0, 0, 0 }, /*.op =*/ GGML_OP_NONE, + /*.op_params =*/ {0}, /*.is_param =*/ false, /*.grad =*/ NULL, /*.src =*/ { NULL }, @@ -4646,6 +4657,21 @@ struct ggml_tensor * ggml_new_tensor_impl( return result; } +static void ggml_set_op_params(struct ggml_tensor * tensor, const void * params, size_t params_size) { + assert(params_size <= GGML_MAX_OP_PARAMS); + memcpy(tensor->op_params, params, params_size); +} + +static int32_t ggml_get_op_params_i32(const struct ggml_tensor * tensor, uint32_t i) { + assert(i < GGML_MAX_OP_PARAMS / sizeof(int32_t)); + return ((const int32_t *)(tensor->op_params))[i]; +} + +static void ggml_set_op_params_i32(struct ggml_tensor * tensor, uint32_t i, int32_t value) { + assert(i < GGML_MAX_OP_PARAMS / sizeof(int32_t)); + ((int32_t *)(tensor->op_params))[i] = value; +} + struct ggml_tensor * ggml_new_tensor( struct ggml_context * ctx, enum ggml_type type, @@ -4977,6 +5003,11 @@ float * ggml_get_data_f32(const struct ggml_tensor * tensor) { return (float *)(tensor->data); } +enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor) { + GGML_ASSERT(tensor->op == GGML_OP_UNARY); + return (enum ggml_unary_op) ggml_get_op_params_i32(tensor, 0); +} + const char * ggml_get_name(const struct ggml_tensor * tensor) { return tensor->name; } @@ -5015,9 +5046,11 @@ struct ggml_tensor * ggml_get_tensor(struct ggml_context * ctx, const char * nam char * const mem_buffer = ctx->mem_buffer; while (obj != NULL) { - struct ggml_tensor * cur = (struct ggml_tensor *)(mem_buffer + obj->offs); - if (strcmp(cur->name, name) == 0) { - return cur; + if (obj->type == GGML_OBJECT_TENSOR) { + struct ggml_tensor * cur = (struct ggml_tensor *)(mem_buffer + obj->offs); + if (strcmp(cur->name, name) == 0) { + return cur; + } } obj = obj->next; @@ -5030,7 +5063,7 @@ struct ggml_tensor * ggml_get_tensor(struct ggml_context * ctx, const char * nam // ggml_dup -struct ggml_tensor * ggml_dup_impl( +static struct ggml_tensor * ggml_dup_impl( struct ggml_context * ctx, struct ggml_tensor * a, bool inplace) { @@ -5045,7 +5078,6 @@ struct ggml_tensor * ggml_dup_impl( result->op = GGML_OP_DUP; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = NULL; return result; } @@ -5064,7 +5096,7 @@ struct ggml_tensor * ggml_dup_inplace( // ggml_add -struct ggml_tensor * ggml_add_impl( +static struct ggml_tensor * ggml_add_impl( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, @@ -5107,7 +5139,7 @@ struct ggml_tensor * ggml_add_inplace( // ggml_add1 -struct ggml_tensor * ggml_add1_impl( +static struct ggml_tensor * ggml_add1_impl( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, @@ -5147,7 +5179,7 @@ struct ggml_tensor * ggml_add1_inplace( // ggml_acc -struct ggml_tensor * ggml_acc_impl( +static struct ggml_tensor * ggml_acc_impl( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, @@ -5169,23 +5201,13 @@ struct ggml_tensor * ggml_acc_impl( struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - ggml_scratch_save(ctx); - - struct ggml_tensor * c = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 5); - - ((int32_t *) c->data)[0] = nb1; - ((int32_t *) c->data)[1] = nb2; - ((int32_t *) c->data)[2] = nb3; - ((int32_t *) c->data)[3] = offset; - ((int32_t *) c->data)[4] = inplace ? 1 : 0; - - ggml_scratch_load(ctx); + int32_t params[] = { nb1, nb2, nb3, offset, inplace ? 1 : 0 }; + ggml_set_op_params(result, params, sizeof(params)); result->op = GGML_OP_ACC; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; result->src[1] = b; - result->src[2] = c; return result; } @@ -5214,7 +5236,7 @@ struct ggml_tensor * ggml_acc_inplace( // ggml_sub -struct ggml_tensor * ggml_sub_impl( +static struct ggml_tensor * ggml_sub_impl( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, @@ -5253,7 +5275,7 @@ struct ggml_tensor * ggml_sub_inplace( // ggml_mul -struct ggml_tensor * ggml_mul_impl( +static struct ggml_tensor * ggml_mul_impl( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, @@ -5300,7 +5322,7 @@ struct ggml_tensor * ggml_mul_inplace( // ggml_div -struct ggml_tensor * ggml_div_impl( +static struct ggml_tensor * ggml_div_impl( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, @@ -5343,7 +5365,7 @@ struct ggml_tensor * ggml_div_inplace( // ggml_sqr -struct ggml_tensor * ggml_sqr_impl( +static struct ggml_tensor * ggml_sqr_impl( struct ggml_context * ctx, struct ggml_tensor * a, bool inplace) { @@ -5358,7 +5380,6 @@ struct ggml_tensor * ggml_sqr_impl( result->op = GGML_OP_SQR; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = NULL; return result; } @@ -5377,7 +5398,7 @@ struct ggml_tensor * ggml_sqr_inplace( // ggml_sqrt -struct ggml_tensor * ggml_sqrt_impl( +static struct ggml_tensor * ggml_sqrt_impl( struct ggml_context * ctx, struct ggml_tensor * a, bool inplace) { @@ -5392,7 +5413,6 @@ struct ggml_tensor * ggml_sqrt_impl( result->op = GGML_OP_SQRT; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = NULL; return result; } @@ -5412,7 +5432,7 @@ struct ggml_tensor * ggml_sqrt_inplace( // ggml_log -struct ggml_tensor * ggml_log_impl( +static struct ggml_tensor * ggml_log_impl( struct ggml_context * ctx, struct ggml_tensor * a, bool inplace) { @@ -5427,7 +5447,6 @@ struct ggml_tensor * ggml_log_impl( result->op = GGML_OP_LOG; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = NULL; return result; } @@ -5460,7 +5479,6 @@ struct ggml_tensor * ggml_sum( result->op = GGML_OP_SUM; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = NULL; return result; } @@ -5487,7 +5505,6 @@ struct ggml_tensor * ggml_sum_rows( result->op = GGML_OP_SUM_ROWS; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = NULL; return result; } @@ -5510,7 +5527,6 @@ struct ggml_tensor * ggml_mean( result->op = GGML_OP_MEAN; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = NULL; return result; } @@ -5534,7 +5550,6 @@ struct ggml_tensor * ggml_argmax( result->op = GGML_OP_ARGMAX; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = NULL; return result; } @@ -5597,343 +5612,142 @@ struct ggml_tensor * ggml_repeat_back( // ggml_abs -struct ggml_tensor * ggml_abs_impl( - struct ggml_context * ctx, - struct ggml_tensor * a, - bool inplace) { - bool is_node = false; - - if (!inplace && (a->grad)) { - is_node = true; - } - - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - - result->op = GGML_OP_ABS; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - result->src[1] = NULL; - - return result; -} - struct ggml_tensor * ggml_abs( struct ggml_context * ctx, struct ggml_tensor * a) { - return ggml_abs_impl(ctx, a, false); + return ggml_unary(ctx, a, GGML_UNARY_OP_ABS); } struct ggml_tensor * ggml_abs_inplace( struct ggml_context * ctx, struct ggml_tensor * a) { - return ggml_abs_impl(ctx, a, true); + return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_ABS); } - // ggml_sgn -struct ggml_tensor * ggml_sgn_impl( - struct ggml_context * ctx, - struct ggml_tensor * a, - bool inplace) { - bool is_node = false; - - if (!inplace && (a->grad)) { - is_node = true; - } - - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - - result->op = GGML_OP_SGN; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - result->src[1] = NULL; - - return result; -} - struct ggml_tensor * ggml_sgn( struct ggml_context * ctx, struct ggml_tensor * a) { - return ggml_sgn_impl(ctx, a, false); + return ggml_unary(ctx, a, GGML_UNARY_OP_SGN); } struct ggml_tensor * ggml_sgn_inplace( struct ggml_context * ctx, struct ggml_tensor * a) { - return ggml_sgn_impl(ctx, a, true); + return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_SGN); } // ggml_neg -struct ggml_tensor * ggml_neg_impl( - struct ggml_context * ctx, - struct ggml_tensor * a, - bool inplace) { - bool is_node = false; - - if (!inplace && (a->grad)) { - is_node = true; - } - - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - - result->op = GGML_OP_NEG; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - result->src[1] = NULL; - - return result; -} - struct ggml_tensor * ggml_neg( struct ggml_context * ctx, struct ggml_tensor * a) { - return ggml_neg_impl(ctx, a, false); + return ggml_unary(ctx, a, GGML_UNARY_OP_NEG); } struct ggml_tensor * ggml_neg_inplace( struct ggml_context * ctx, struct ggml_tensor * a) { - return ggml_neg_impl(ctx, a, true); + return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_NEG); } // ggml_step -struct ggml_tensor * ggml_step_impl( - struct ggml_context * ctx, - struct ggml_tensor * a, - bool inplace) { - bool is_node = false; - - if (!inplace && (a->grad)) { - is_node = true; - } - - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - - result->op = GGML_OP_STEP; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - result->src[1] = NULL; - - return result; -} - struct ggml_tensor * ggml_step( struct ggml_context * ctx, struct ggml_tensor * a) { - return ggml_step_impl(ctx, a, false); + return ggml_unary(ctx, a, GGML_UNARY_OP_STEP); } struct ggml_tensor * ggml_step_inplace( struct ggml_context * ctx, struct ggml_tensor * a) { - return ggml_step_impl(ctx, a, true); + return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_STEP); } // ggml_tanh -struct ggml_tensor * ggml_tanh_impl( - struct ggml_context * ctx, - struct ggml_tensor * a, - bool inplace) { - bool is_node = false; - - if (!inplace && (a->grad)) { - is_node = true; - } - - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - - result->op = GGML_OP_TANH; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - result->src[1] = NULL; - - return result; -} - struct ggml_tensor * ggml_tanh( struct ggml_context * ctx, struct ggml_tensor * a) { - return ggml_tanh_impl(ctx, a, false); + return ggml_unary(ctx, a, GGML_UNARY_OP_TANH); } struct ggml_tensor * ggml_tanh_inplace( struct ggml_context * ctx, struct ggml_tensor * a) { - return ggml_tanh_impl(ctx, a, true); + return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_TANH); } // ggml_elu -struct ggml_tensor * ggml_elu_impl( - struct ggml_context * ctx, - struct ggml_tensor * a, - bool inplace) { - bool is_node = false; - - if (!inplace && (a->grad)) { - is_node = true; - } - - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - - result->op = GGML_OP_ELU; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - result->src[1] = NULL; - - return result; -} - struct ggml_tensor * ggml_elu( struct ggml_context * ctx, struct ggml_tensor * a) { - return ggml_elu_impl(ctx, a, false); + return ggml_unary(ctx, a, GGML_UNARY_OP_ELU); } struct ggml_tensor * ggml_elu_inplace( struct ggml_context * ctx, struct ggml_tensor * a) { - return ggml_elu_impl(ctx, a, true); + return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_ELU); } // ggml_relu -struct ggml_tensor * ggml_relu_impl( - struct ggml_context * ctx, - struct ggml_tensor * a, - bool inplace) { - bool is_node = false; - - if (!inplace && (a->grad)) { - is_node = true; - } - - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - - result->op = GGML_OP_RELU; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - result->src[1] = NULL; - - return result; -} - struct ggml_tensor * ggml_relu( struct ggml_context * ctx, struct ggml_tensor * a) { - return ggml_relu_impl(ctx, a, false); + return ggml_unary(ctx, a, GGML_UNARY_OP_RELU); } struct ggml_tensor * ggml_relu_inplace( struct ggml_context * ctx, struct ggml_tensor * a) { - return ggml_relu_impl(ctx, a, true); + return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_RELU); } // ggml_gelu -struct ggml_tensor * ggml_gelu_impl( - struct ggml_context * ctx, - struct ggml_tensor * a, - bool inplace) { - bool is_node = false; - - if (!inplace && (a->grad)) { - is_node = true; - } - - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - - result->op = GGML_OP_GELU; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - result->src[1] = NULL; - - return result; -} - struct ggml_tensor * ggml_gelu( struct ggml_context * ctx, struct ggml_tensor * a) { - return ggml_gelu_impl(ctx, a, false); + return ggml_unary(ctx, a, GGML_UNARY_OP_GELU); } struct ggml_tensor * ggml_gelu_inplace( struct ggml_context * ctx, struct ggml_tensor * a) { - return ggml_gelu_impl(ctx, a, true); + return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_GELU); } // ggml_gelu_quick -struct ggml_tensor * ggml_gelu_quick_impl( - struct ggml_context * ctx, - struct ggml_tensor * a, - bool inplace) { - bool is_node = false; - - if (!inplace && (a->grad)) { - is_node = true; - } - - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - - result->op = GGML_OP_GELU_QUICK; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - result->src[1] = NULL; - - return result; -} - struct ggml_tensor * ggml_gelu_quick( struct ggml_context * ctx, struct ggml_tensor * a) { - return ggml_gelu_quick_impl(ctx, a, false); + return ggml_unary(ctx, a, GGML_UNARY_OP_GELU_QUICK); } struct ggml_tensor * ggml_gelu_quick_inplace( struct ggml_context * ctx, struct ggml_tensor * a) { - return ggml_gelu_quick_impl(ctx, a, true); + return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_GELU_QUICK); } // ggml_silu -struct ggml_tensor * ggml_silu_impl( - struct ggml_context * ctx, - struct ggml_tensor * a, - bool inplace) { - bool is_node = false; - - if (!inplace && (a->grad)) { - is_node = true; - } - - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - - result->op = GGML_OP_SILU; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - result->src[1] = NULL; - - return result; -} - struct ggml_tensor * ggml_silu( struct ggml_context * ctx, struct ggml_tensor * a) { - return ggml_silu_impl(ctx, a, false); + return ggml_unary(ctx, a, GGML_UNARY_OP_SILU); } struct ggml_tensor * ggml_silu_inplace( struct ggml_context * ctx, struct ggml_tensor * a) { - return ggml_silu_impl(ctx, a, true); + return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_SILU); } // ggml_silu_back @@ -5961,7 +5775,7 @@ struct ggml_tensor * ggml_silu_back( // ggml_norm -struct ggml_tensor * ggml_norm_impl( +static struct ggml_tensor * ggml_norm_impl( struct ggml_context * ctx, struct ggml_tensor * a, bool inplace) { @@ -5974,10 +5788,11 @@ struct ggml_tensor * ggml_norm_impl( struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + // TODO: maybe store epsilon here? + result->op = GGML_OP_NORM; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = NULL; // TODO: maybe store epsilon here? return result; } @@ -5994,9 +5809,10 @@ struct ggml_tensor * ggml_norm_inplace( return ggml_norm_impl(ctx, a, true); } -struct ggml_tensor * ggml_rms_norm_impl( +static struct ggml_tensor * ggml_rms_norm_impl( struct ggml_context * ctx, struct ggml_tensor * a, + float eps, bool inplace) { bool is_node = false; @@ -6006,24 +5822,27 @@ struct ggml_tensor * ggml_rms_norm_impl( struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + ggml_set_op_params(result, &eps, sizeof(eps)); + result->op = GGML_OP_RMS_NORM; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = NULL; // TODO: maybe store epsilon here? return result; } struct ggml_tensor * ggml_rms_norm( struct ggml_context * ctx, - struct ggml_tensor * a) { - return ggml_rms_norm_impl(ctx, a, false); + struct ggml_tensor * a, + float eps) { + return ggml_rms_norm_impl(ctx, a, eps, false); } struct ggml_tensor * ggml_rms_norm_inplace( struct ggml_context * ctx, - struct ggml_tensor * a) { - return ggml_rms_norm_impl(ctx, a, true); + struct ggml_tensor * a, + float eps) { + return ggml_rms_norm_impl(ctx, a, eps, true); } struct ggml_tensor * ggml_rms_norm_back( @@ -6102,7 +5921,7 @@ struct ggml_tensor * ggml_out_prod( // ggml_scale -struct ggml_tensor * ggml_scale_impl( +static struct ggml_tensor * ggml_scale_impl( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, @@ -6142,7 +5961,7 @@ struct ggml_tensor * ggml_scale_inplace( // ggml_set -struct ggml_tensor * ggml_set_impl( +static struct ggml_tensor * ggml_set_impl( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, @@ -6162,23 +5981,13 @@ struct ggml_tensor * ggml_set_impl( // make a view of the destination struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - ggml_scratch_save(ctx); - - struct ggml_tensor * c = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 5); - - (( int32_t * ) c->data)[0] = nb1; - (( int32_t * ) c->data)[1] = nb2; - (( int32_t * ) c->data)[2] = nb3; - (( int32_t * ) c->data)[3] = offset; - (( int32_t * ) c->data)[4] = inplace ? 1 : 0; - - ggml_scratch_load(ctx); + int32_t params[] = { nb1, nb2, nb3, offset, inplace ? 1 : 0 }; + ggml_set_op_params(result, params, sizeof(params)); result->op = GGML_OP_SET; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; result->src[1] = b; - result->src[2] = c; return result; } @@ -6242,7 +6051,7 @@ struct ggml_tensor * ggml_set_2d_inplace( // ggml_cpy -struct ggml_tensor * ggml_cpy_impl( +static struct ggml_tensor * ggml_cpy_impl( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, @@ -6287,7 +6096,7 @@ struct ggml_tensor * ggml_cpy_inplace( // ggml_cont -struct ggml_tensor * ggml_cont_impl( +static struct ggml_tensor * ggml_cont_impl( struct ggml_context * ctx, struct ggml_tensor * a, bool inplace) { @@ -6303,7 +6112,6 @@ struct ggml_tensor * ggml_cont_impl( result->op = GGML_OP_CONT; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = NULL; return result; } @@ -6347,7 +6155,6 @@ struct ggml_tensor * ggml_reshape( result->op = GGML_OP_RESHAPE; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = NULL; return result; } @@ -6372,7 +6179,6 @@ struct ggml_tensor * ggml_reshape_1d( result->op = GGML_OP_RESHAPE; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = NULL; return result; } @@ -6398,7 +6204,6 @@ struct ggml_tensor * ggml_reshape_2d( result->op = GGML_OP_RESHAPE; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = NULL; return result; } @@ -6425,7 +6230,6 @@ struct ggml_tensor * ggml_reshape_3d( result->op = GGML_OP_RESHAPE; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = NULL; return result; } @@ -6454,7 +6258,6 @@ struct ggml_tensor * ggml_reshape_4d( result->op = GGML_OP_RESHAPE; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = NULL; return result; } @@ -6476,19 +6279,11 @@ struct ggml_tensor * ggml_view_1d( struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 1, &ne0, (char *) a->data + offset); ggml_format_name(result, "%s (view)", a->name); - ggml_scratch_save(ctx); - - struct ggml_tensor * offs = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2); - ggml_set_name(offs, "offset"); - memcpy(offs->data, &offset, 2*sizeof(int32_t)); - - ggml_scratch_load(ctx); + ggml_set_op_params(result, &offset, sizeof(offset)); result->op = GGML_OP_VIEW; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = NULL; - result->src[2] = offs; return result; } @@ -6514,13 +6309,7 @@ struct ggml_tensor * ggml_view_2d( struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 2, ne, (char *) a->data + offset); ggml_format_name(result, "%s (view)", a->name); - ggml_scratch_save(ctx); - - struct ggml_tensor * offs = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2); - ggml_set_name(offs, "offset"); - memcpy(offs->data, &offset, 2*sizeof(int32_t)); - - ggml_scratch_load(ctx); + ggml_set_op_params(result, &offset, sizeof(offset)); result->nb[1] = nb1; result->nb[2] = result->nb[1]*ne1; @@ -6529,8 +6318,6 @@ struct ggml_tensor * ggml_view_2d( result->op = GGML_OP_VIEW; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = NULL; - result->src[2] = offs; return result; } @@ -6558,13 +6345,7 @@ struct ggml_tensor * ggml_view_3d( struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 3, ne, (char *) a->data + offset); ggml_format_name(result, "%s (view)", a->name); - ggml_scratch_save(ctx); - - struct ggml_tensor * offs = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2); - ggml_set_name(offs, "offset"); - memcpy(offs->data, &offset, 2*sizeof(int32_t)); - - ggml_scratch_load(ctx); + ggml_set_op_params(result, &offset, sizeof(offset)); result->nb[1] = nb1; result->nb[2] = nb2; @@ -6573,8 +6354,6 @@ struct ggml_tensor * ggml_view_3d( result->op = GGML_OP_VIEW; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = NULL; - result->src[2] = offs; return result; } @@ -6604,13 +6383,7 @@ struct ggml_tensor * ggml_view_4d( struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 4, ne, (char *) a->data + offset); ggml_format_name(result, "%s (view)", a->name); - ggml_scratch_save(ctx); - - struct ggml_tensor * offs = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2); - ggml_set_name(offs, "offset"); - memcpy(offs->data, &offset, 2*sizeof(int32_t)); - - ggml_scratch_load(ctx); + ggml_set_op_params(result, &offset, sizeof(offset)); result->nb[1] = nb1; result->nb[2] = nb2; @@ -6619,8 +6392,6 @@ struct ggml_tensor * ggml_view_4d( result->op = GGML_OP_VIEW; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = NULL; - result->src[2] = offs; return result; } @@ -6681,22 +6452,9 @@ struct ggml_tensor * ggml_permute( result->op = GGML_OP_PERMUTE; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = NULL; - if (is_node) { - ggml_scratch_save(ctx); - - struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 4); - - ((int32_t *) b->data)[0] = axis0; - ((int32_t *) b->data)[1] = axis1; - ((int32_t *) b->data)[2] = axis2; - ((int32_t *) b->data)[3] = axis3; - - ggml_scratch_load(ctx); - - result->src[2] = b; - } + int32_t params[] = { axis0, axis1, axis2, axis3 }; + ggml_set_op_params(result, ¶ms, sizeof(params)); return result; } @@ -6724,7 +6482,6 @@ struct ggml_tensor * ggml_transpose( result->op = GGML_OP_TRANSPOSE; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = NULL; return result; } @@ -6802,7 +6559,6 @@ struct ggml_tensor * ggml_diag( result->op = GGML_OP_DIAG; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = NULL; return result; } @@ -6810,7 +6566,7 @@ struct ggml_tensor * ggml_diag( // ggml_diag_mask_inf -struct ggml_tensor * ggml_diag_mask_inf_impl( +static struct ggml_tensor * ggml_diag_mask_inf_impl( struct ggml_context * ctx, struct ggml_tensor * a, int n_past, @@ -6823,19 +6579,12 @@ struct ggml_tensor * ggml_diag_mask_inf_impl( struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - ggml_scratch_save(ctx); - - struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2); - - ((int32_t *) b->data)[0] = n_past; - ((int32_t *) b->data)[1] = inplace ? 1 : 0; - - ggml_scratch_load(ctx); + int32_t params[] = { n_past, inplace ? 1 : 0 }; + ggml_set_op_params(result, ¶ms, sizeof(params)); result->op = GGML_OP_DIAG_MASK_INF; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = b; return result; } @@ -6857,7 +6606,7 @@ struct ggml_tensor * ggml_diag_mask_inf_inplace( // ggml_diag_mask_zero -struct ggml_tensor * ggml_diag_mask_zero_impl( +static struct ggml_tensor * ggml_diag_mask_zero_impl( struct ggml_context * ctx, struct ggml_tensor * a, int n_past, @@ -6870,20 +6619,12 @@ struct ggml_tensor * ggml_diag_mask_zero_impl( struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - ggml_scratch_save(ctx); - - struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2); - ggml_set_name(b, "n_past, inplace"); - - ((int32_t *) b->data)[0] = n_past; - ((int32_t *) b->data)[1] = inplace ? 1 : 0; - - ggml_scratch_load(ctx); + int32_t params[] = { n_past, inplace ? 1 : 0 }; + ggml_set_op_params(result, ¶ms, sizeof(params)); result->op = GGML_OP_DIAG_MASK_ZERO; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = b; return result; } @@ -6904,7 +6645,7 @@ struct ggml_tensor * ggml_diag_mask_zero_inplace( // ggml_soft_max -struct ggml_tensor * ggml_soft_max_impl( +static struct ggml_tensor * ggml_soft_max_impl( struct ggml_context * ctx, struct ggml_tensor * a, bool inplace) { @@ -6919,7 +6660,6 @@ struct ggml_tensor * ggml_soft_max_impl( result->op = GGML_OP_SOFT_MAX; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = NULL; return result; } @@ -6939,7 +6679,7 @@ struct ggml_tensor * ggml_soft_max_inplace( // ggml_soft_max_back -struct ggml_tensor * ggml_soft_max_back_impl( +static struct ggml_tensor * ggml_soft_max_back_impl( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, @@ -6976,15 +6716,15 @@ struct ggml_tensor * ggml_soft_max_back_inplace( // ggml_rope -struct ggml_tensor * ggml_rope_impl( +static struct ggml_tensor * ggml_rope_impl( struct ggml_context * ctx, struct ggml_tensor * a, int n_past, int n_dims, int mode, + int n_ctx, float freq_base, float freq_scale, - int n_ctx, bool inplace) { GGML_ASSERT(n_past >= 0); bool is_node = false; @@ -6995,23 +6735,14 @@ struct ggml_tensor * ggml_rope_impl( struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - ggml_scratch_save(ctx); - - struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 6); - - ((int32_t *) b->data)[0] = n_past; - ((int32_t *) b->data)[1] = n_dims; - ((int32_t *) b->data)[2] = mode; - ((int32_t *) b->data)[3] = n_ctx; - memcpy((int32_t *) b->data + 4, &freq_base, sizeof(float)); - memcpy((int32_t *) b->data + 5, &freq_scale, sizeof(float)); - - ggml_scratch_load(ctx); + int32_t params[6] = { n_past, n_dims, mode, n_ctx }; + memcpy(params + 4, &freq_base, sizeof(float)); + memcpy(params + 5, &freq_scale, sizeof(float)); + ggml_set_op_params(result, ¶ms, sizeof(params)); result->op = GGML_OP_ROPE; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = b; return result; } @@ -7023,7 +6754,7 @@ struct ggml_tensor * ggml_rope( int n_dims, int mode, int n_ctx) { - return ggml_rope_impl(ctx, a, n_past, n_dims, mode, 10000.0f, 1.0f, n_ctx, false); + return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, false); } struct ggml_tensor * ggml_rope_inplace( @@ -7033,7 +6764,7 @@ struct ggml_tensor * ggml_rope_inplace( int n_dims, int mode, int n_ctx) { - return ggml_rope_impl(ctx, a, n_past, n_dims, mode, 10000.0f, 1.0f, n_ctx, true); + return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, true); } struct ggml_tensor * ggml_rope_custom_inplace( @@ -7042,10 +6773,10 @@ struct ggml_tensor * ggml_rope_custom_inplace( int n_past, int n_dims, int mode, + int n_ctx, float freq_base, - float freq_scale, - int n_ctx) { - return ggml_rope_impl(ctx, a, n_past, n_dims, mode, freq_base, freq_scale, n_ctx, true); + float freq_scale) { + return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, true); } // ggml_rope_back @@ -7055,7 +6786,8 @@ struct ggml_tensor * ggml_rope_back( struct ggml_tensor * a, int n_past, int n_dims, - int mode) { + int mode, + int n_ctx) { GGML_ASSERT(n_past >= 0); GGML_ASSERT((mode & 4) == 0 && "ggml_rope_back() for ChatGLM not implemented yet"); @@ -7067,21 +6799,12 @@ struct ggml_tensor * ggml_rope_back( struct ggml_tensor * result = ggml_dup_tensor(ctx, a); - ggml_scratch_save(ctx); - - struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3); - ggml_set_name(b, "n_past, n_dims, mode"); - - ((int32_t *) b->data)[0] = n_past; - ((int32_t *) b->data)[1] = n_dims; - ((int32_t *) b->data)[2] = mode; - - ggml_scratch_load(ctx); + int32_t params[] = { n_past, n_dims, mode, n_ctx }; + ggml_set_op_params(result, ¶ms, sizeof(params)); result->op = GGML_OP_ROPE_BACK; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = b; return result; } @@ -7106,21 +6829,13 @@ struct ggml_tensor * ggml_alibi( //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); struct ggml_tensor * result = ggml_view_tensor(ctx, a); - ggml_scratch_save(ctx); - - struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3); - - ((int32_t *) b->data)[0] = n_past; - ((int32_t *) b->data)[1] = n_head; - GGML_ASSERT(sizeof(float) == sizeof(int32_t)); - (((float *) b->data)[2]) = bias_max; - - ggml_scratch_load(ctx); + int32_t op_params[3] = { n_past, n_head }; + memcpy(op_params + 2, &bias_max, sizeof(float)); + ggml_set_op_params(result, &op_params, sizeof(op_params)); result->op = GGML_OP_ALIBI; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = b; return result; } @@ -7142,19 +6857,12 @@ struct ggml_tensor * ggml_clamp( // TODO: when implement backward, fix this: struct ggml_tensor * result = ggml_view_tensor(ctx, a); - ggml_scratch_save(ctx); - - struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 2); - - ((float *) b->data)[0] = min; - ((float *) b->data)[1] = max; - - ggml_scratch_load(ctx); + float params[] = { min, max }; + ggml_set_op_params(result, ¶ms, sizeof(params)); result->op = GGML_OP_CLAMP; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = b; return result; } @@ -7187,18 +6895,13 @@ GGML_API struct ggml_tensor * ggml_conv_1d( }; struct ggml_tensor* result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne); - ggml_scratch_save(ctx); - struct ggml_tensor* c = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3); - ((int32_t*)c->data)[0] = s0; - ((int32_t*)c->data)[1] = p0; - ((int32_t*)c->data)[2] = d0; - ggml_scratch_load(ctx); + int32_t params[] = { s0, p0, d0 }; + ggml_set_op_params(result, ¶ms, sizeof(params)); result->op = GGML_OP_CONV_1D; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; result->src[1] = b; - result->src[2] = c; return result; } @@ -7231,21 +6934,13 @@ struct ggml_tensor* ggml_conv_2d( }; struct ggml_tensor* result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); - ggml_scratch_save(ctx); - struct ggml_tensor* c = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 6); - ((int32_t*)c->data)[0] = s0; - ((int32_t*)c->data)[1] = s1; - ((int32_t*)c->data)[2] = p0; - ((int32_t*)c->data)[3] = p1; - ((int32_t*)c->data)[4] = d0; - ((int32_t*)c->data)[5] = d1; - ggml_scratch_load(ctx); + int32_t params[] = { s0, s1, p0, p1, d0, d1 }; + ggml_set_op_params(result, ¶ms, sizeof(params)); result->op = GGML_OP_CONV_2D; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; result->src[1] = b; - result->src[2] = c; return result; @@ -7269,7 +6964,7 @@ static int64_t ggml_calc_pool_output_size(int64_t ins, int ks, int s, int p) { return (ins + 2 * p - ks) / s + 1; } -// ggml_pool_2d +// ggml_pool_1d struct ggml_tensor* ggml_pool_1d( struct ggml_context * ctx, @@ -7292,18 +6987,12 @@ struct ggml_tensor* ggml_pool_1d( }; struct ggml_tensor* result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne); - ggml_scratch_save(ctx); - struct ggml_tensor* c = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 4); - ((int32_t*)c->data)[0] = op; - ((int32_t*)c->data)[1] = k0; - ((int32_t*)c->data)[2] = s0; - ((int32_t*)c->data)[3] = p0; - ggml_scratch_load(ctx); + int32_t params[] = { op, k0, s0, p0 }; + ggml_set_op_params(result, ¶ms, sizeof(params)); result->op = GGML_OP_POOL_1D; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = c; return result; } @@ -7335,21 +7024,12 @@ struct ggml_tensor* ggml_pool_2d( }; struct ggml_tensor* result = ggml_new_tensor(ctx, GGML_TYPE_F32, 3, ne); - ggml_scratch_save(ctx); - struct ggml_tensor* c = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 7); - ((int32_t*)c->data)[0] = op; - ((int32_t*)c->data)[1] = k0; - ((int32_t*)c->data)[2] = k1; - ((int32_t*)c->data)[3] = s0; - ((int32_t*)c->data)[4] = s1; - ((int32_t*)c->data)[5] = p0; - ((int32_t*)c->data)[6] = p1; - ggml_scratch_load(ctx); + int32_t params[] = { op, k0, k1, s0, s1, p0, p1 }; + ggml_set_op_params(result, ¶ms, sizeof(params)); result->op = GGML_OP_POOL_2D; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = c; return result; } @@ -7372,14 +7052,16 @@ struct ggml_tensor * ggml_flash_attn( } //struct ggml_tensor * result = ggml_dup_tensor(ctx, q); - struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, q->ne); + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, q->n_dims, q->ne); + + int32_t t = masked ? 1 : 0; + ggml_set_op_params(result, &t, sizeof(t)); result->op = GGML_OP_FLASH_ATTN; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = q; result->src[1] = k; result->src[2] = v; - result->src[3] = ggml_new_i32(ctx, masked ? 1 : 0); return result; } @@ -7403,7 +7085,7 @@ struct ggml_tensor * ggml_flash_ff( } //struct ggml_tensor * result = ggml_dup_tensor(ctx, a); - struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, a->ne); + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, a->n_dims, a->ne); result->op = GGML_OP_FLASH_FF; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; @@ -7469,13 +7151,15 @@ struct ggml_tensor * ggml_flash_attn_back( struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); + int32_t masked_i = masked ? 1 : 0; + ggml_set_op_params(result, &masked_i, sizeof(masked_i)); + result->op = GGML_OP_FLASH_ATTN_BACK; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = q; result->src[1] = k; result->src[2] = v; result->src[3] = d; - result->src[4] = ggml_new_i32(ctx, masked ? 1 : 0); return result; } @@ -7508,21 +7192,12 @@ struct ggml_tensor * ggml_win_part( struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); - ggml_scratch_save(ctx); - - struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3); - - ((int32_t *) b->data)[0] = npx; - ((int32_t *) b->data)[1] = npy; - ((int32_t *) b->data)[2] = w; - - ggml_scratch_load(ctx); + int32_t params[] = { npx, npy, w }; + ggml_set_op_params(result, ¶ms, sizeof(params)); result->op = GGML_OP_WIN_PART; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = NULL; - result->src[2] = b; return result; } @@ -7547,26 +7222,57 @@ struct ggml_tensor * ggml_win_unpart( const int64_t ne[4] = { a->ne[0], w0, h0, 1, }; struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 3, ne); - ggml_scratch_save(ctx); - - struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1); - - ((int32_t *) b->data)[0] = w; - - ggml_scratch_load(ctx); + int32_t params[] = { w }; + ggml_set_op_params(result, ¶ms, sizeof(params)); result->op = GGML_OP_WIN_UNPART; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = NULL; - result->src[2] = b; return result; } +// gmml_unary + +static struct ggml_tensor * ggml_unary_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + enum ggml_unary_op op, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad)) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + ggml_set_op_params_i32(result, 0, (int32_t) op); + + result->op = GGML_OP_UNARY; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +struct ggml_tensor * ggml_unary( + struct ggml_context * ctx, + struct ggml_tensor * a, + enum ggml_unary_op op) { + return ggml_unary_impl(ctx, a, op, false); +} + +struct ggml_tensor * ggml_unary_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + enum ggml_unary_op op) { + return ggml_unary_impl(ctx, a, op, true); +} + // ggml_map_unary -struct ggml_tensor * ggml_map_unary_impl_f32( +static struct ggml_tensor * ggml_map_unary_impl_f32( struct ggml_context * ctx, struct ggml_tensor * a, const ggml_unary_op_f32_t fun, @@ -7577,19 +7283,13 @@ struct ggml_tensor * ggml_map_unary_impl_f32( is_node = true; } - struct ggml_tensor *result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - ggml_scratch_save(ctx); - - struct ggml_tensor * addr_tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(void *) / sizeof(int32_t)); - *((void (**)(void))addr_tensor->data) = (void (*)(void))fun; - - ggml_scratch_load(ctx); + ggml_set_op_params(result, (const void *) &fun, sizeof(fun)); result->op = GGML_OP_MAP_UNARY; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[2] = addr_tensor; return result; } @@ -7610,7 +7310,7 @@ struct ggml_tensor * ggml_map_unary_inplace_f32( // ggml_map_binary -struct ggml_tensor * ggml_map_binary_impl_f32( +static struct ggml_tensor * ggml_map_binary_impl_f32( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, @@ -7624,20 +7324,14 @@ struct ggml_tensor * ggml_map_binary_impl_f32( is_node = true; } - struct ggml_tensor *result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - ggml_scratch_save(ctx); - - struct ggml_tensor * addr_tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(void *) / sizeof(int32_t)); - *((void (**)(void))addr_tensor->data) = (void (*)(void))fun; - - ggml_scratch_load(ctx); + ggml_set_op_params(result, (const void *) &fun, sizeof(fun)); result->op = GGML_OP_MAP_BINARY; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; result->src[1] = b; - result->src[2] = addr_tensor; return result; } @@ -7660,7 +7354,7 @@ struct ggml_tensor * ggml_map_binary_inplace_f32( // ggml_map_custom1 -struct ggml_tensor * ggml_map_custom1_impl_f32( +static struct ggml_tensor * ggml_map_custom1_impl_f32( struct ggml_context * ctx, struct ggml_tensor * a, const ggml_custom1_op_f32_t fun, @@ -7671,19 +7365,13 @@ struct ggml_tensor * ggml_map_custom1_impl_f32( is_node = true; } - struct ggml_tensor *result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - ggml_scratch_save(ctx); - - struct ggml_tensor * addr_tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(void *) / sizeof(int32_t)); - *((void (**)(void))addr_tensor->data) = (void (*)(void))fun; - - ggml_scratch_load(ctx); + ggml_set_op_params(result, (const void *) &fun, sizeof(fun)); result->op = GGML_OP_MAP_CUSTOM1; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[2] = addr_tensor; return result; } @@ -7704,7 +7392,7 @@ struct ggml_tensor * ggml_map_custom1_inplace_f32( // ggml_map_custom2 -struct ggml_tensor * ggml_map_custom2_impl_f32( +static struct ggml_tensor * ggml_map_custom2_impl_f32( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, @@ -7716,20 +7404,14 @@ struct ggml_tensor * ggml_map_custom2_impl_f32( is_node = true; } - struct ggml_tensor *result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - ggml_scratch_save(ctx); - - struct ggml_tensor * addr_tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(void *) / sizeof(int32_t)); - *((void (**)(void))addr_tensor->data) = (void (*)(void))fun; - - ggml_scratch_load(ctx); + ggml_set_op_params(result, (const void *) &fun, sizeof(fun)); result->op = GGML_OP_MAP_CUSTOM2; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; result->src[1] = b; - result->src[2] = addr_tensor; return result; } @@ -7752,7 +7434,7 @@ struct ggml_tensor * ggml_map_custom2_inplace_f32( // ggml_map_custom3 -struct ggml_tensor * ggml_map_custom3_impl_f32( +static struct ggml_tensor * ggml_map_custom3_impl_f32( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, @@ -7765,21 +7447,15 @@ struct ggml_tensor * ggml_map_custom3_impl_f32( is_node = true; } - struct ggml_tensor *result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - ggml_scratch_save(ctx); - - struct ggml_tensor * addr_tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(void *) / sizeof(int32_t)); - *((void (**)(void))addr_tensor->data) = (void (*)(void))fun; - - ggml_scratch_load(ctx); + ggml_set_op_params(result, (const void *) &fun, sizeof(fun)); result->op = GGML_OP_MAP_CUSTOM3; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; result->src[1] = b; - result->src[2] = addr_tensor; - result->src[3] = c; + result->src[2] = c; return result; } @@ -9007,21 +8683,17 @@ static void ggml_compute_forward_acc_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, const struct ggml_tensor * src1, - const struct ggml_tensor * opt0, struct ggml_tensor * dst) { GGML_ASSERT(ggml_are_same_shape(src0, dst)); GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0)); - GGML_ASSERT(opt0->type == GGML_TYPE_I32); - GGML_ASSERT(ggml_nelements(opt0) == 5); - // view src0 and dst with these strides and data offset inbytes during acc // nb0 is implicitely element_size because src0 and dst are contiguous - size_t nb1 = ((int32_t *) opt0->data)[0]; - size_t nb2 = ((int32_t *) opt0->data)[1]; - size_t nb3 = ((int32_t *) opt0->data)[2]; - size_t offset = ((int32_t *) opt0->data)[3]; - bool inplace = (bool) ((int32_t *) opt0->data)[4]; + size_t nb1 = ((int32_t *) dst->op_params)[0]; + size_t nb2 = ((int32_t *) dst->op_params)[1]; + size_t nb3 = ((int32_t *) dst->op_params)[2]; + size_t offset = ((int32_t *) dst->op_params)[3]; + bool inplace = (bool) ((int32_t *) dst->op_params)[4]; if (!inplace && (params->type == GGML_TASK_INIT)) { // memcpy needs to be synchronized across threads to avoid race conditions. @@ -9090,13 +8762,12 @@ static void ggml_compute_forward_acc( const struct ggml_compute_params * params, const struct ggml_tensor * src0, const struct ggml_tensor * src1, - const struct ggml_tensor * opt0, struct ggml_tensor * dst) { switch (src0->type) { case GGML_TYPE_F32: { - ggml_compute_forward_acc_f32(params, src0, src1, opt0, dst); + ggml_compute_forward_acc_f32(params, src0, src1, dst); } break; case GGML_TYPE_F16: case GGML_TYPE_Q4_0: @@ -9528,7 +9199,7 @@ static void ggml_compute_forward_sum_f32( for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i02 = 0; i02 < ne02; i02++) { for (int64_t i01 = 0; i01 < ne01; i01++) { - ggml_vec_sum_ggf(ne00, + ggml_vec_sum_f32_ggf(ne00, &row_sum, (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03)); sum += row_sum; @@ -9538,6 +9209,38 @@ static void ggml_compute_forward_sum_f32( ((float *) dst->data)[0] = sum; } +static void ggml_compute_forward_sum_f16( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(ggml_is_scalar(dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + assert(src0->nb[0] == sizeof(ggml_fp16_t)); + + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne); + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb); + + float sum = 0; + float row_sum = 0; + + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = 0; i01 < ne01; i01++) { + ggml_vec_sum_f16_ggf(ne00, + &row_sum, + (ggml_fp16_t *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03)); + sum += row_sum; + } + } + } + ((ggml_fp16_t *) dst->data)[0] = GGML_FP32_TO_FP16(sum); +} + static void ggml_compute_forward_sum( const struct ggml_compute_params * params, const struct ggml_tensor * src0, @@ -9547,6 +9250,10 @@ static void ggml_compute_forward_sum( { ggml_compute_forward_sum_f32(params, src0, dst); } break; + case GGML_TYPE_F16: + { + ggml_compute_forward_sum_f16(params, src0, dst); + } break; default: { GGML_ASSERT(false); @@ -10142,8 +9849,8 @@ static void ggml_compute_forward_gelu_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, struct ggml_tensor * dst) { - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(ggml_is_contiguous(dst)); + GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0)); + GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst)); GGML_ASSERT(ggml_are_same_shape(src0, dst)); if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { @@ -10201,8 +9908,8 @@ static void ggml_compute_forward_gelu_quick_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, struct ggml_tensor * dst) { - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(ggml_is_contiguous(dst)); + GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0)); + GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst)); GGML_ASSERT(ggml_are_same_shape(src0, dst)); if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { @@ -10260,8 +9967,8 @@ static void ggml_compute_forward_silu_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, struct ggml_tensor * dst) { - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(ggml_is_contiguous(dst)); + GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0)); + GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst)); GGML_ASSERT(ggml_are_same_shape(src0, dst)); if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { @@ -10313,7 +10020,6 @@ static void ggml_compute_forward_silu( } } - // ggml_compute_forward_silu_back static void ggml_compute_forward_silu_back_f32( @@ -10321,9 +10027,9 @@ static void ggml_compute_forward_silu_back_f32( const struct ggml_tensor * src0, const struct ggml_tensor * grad, struct ggml_tensor * dst) { - GGML_ASSERT(ggml_is_contiguous(grad)); - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(ggml_is_contiguous(dst)); + GGML_ASSERT(ggml_is_contiguous_except_dim_1(grad)); + GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0)); + GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst)); GGML_ASSERT(ggml_are_same_shape(src0, dst)); GGML_ASSERT(ggml_are_same_shape(src0, grad)); @@ -10463,7 +10169,8 @@ static void ggml_compute_forward_rms_norm_f32( GGML_TENSOR_UNARY_OP_LOCALS; - const float eps = 1e-6f; // TODO: make this a parameter + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); // TODO: optimize for (int64_t i03 = 0; i03 < ne03; i03++) { @@ -11116,21 +10823,17 @@ static void ggml_compute_forward_set_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, const struct ggml_tensor * src1, - const struct ggml_tensor * opt0, struct ggml_tensor * dst) { GGML_ASSERT(ggml_are_same_shape(src0, dst)); GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0)); - GGML_ASSERT(opt0->type == GGML_TYPE_I32); - GGML_ASSERT(ggml_nelements(opt0) == 5); - // view src0 and dst with these strides and data offset inbytes during set // nb0 is implicitely element_size because src0 and dst are contiguous - size_t nb1 = ((int32_t *) opt0->data)[0]; - size_t nb2 = ((int32_t *) opt0->data)[1]; - size_t nb3 = ((int32_t *) opt0->data)[2]; - size_t offset = ((int32_t *) opt0->data)[3]; - bool inplace = (bool) ((int32_t *) opt0->data)[4]; + size_t nb1 = ((int32_t *) dst->op_params)[0]; + size_t nb2 = ((int32_t *) dst->op_params)[1]; + size_t nb3 = ((int32_t *) dst->op_params)[2]; + size_t offset = ((int32_t *) dst->op_params)[3]; + bool inplace = (bool) ((int32_t *) dst->op_params)[4]; if (!inplace && (params->type == GGML_TASK_INIT)) { // memcpy needs to be synchronized across threads to avoid race conditions. @@ -11190,13 +10893,12 @@ static void ggml_compute_forward_set( const struct ggml_compute_params * params, const struct ggml_tensor * src0, const struct ggml_tensor * src1, - const struct ggml_tensor * opt0, struct ggml_tensor * dst) { switch (src0->type) { case GGML_TYPE_F32: { - ggml_compute_forward_set_f32(params, src0, src1, opt0, dst); + ggml_compute_forward_set_f32(params, src0, src1, dst); } break; case GGML_TYPE_F16: case GGML_TYPE_Q4_0: @@ -11592,17 +11294,14 @@ static void ggml_compute_forward_diag( static void ggml_compute_forward_diag_mask_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, - const struct ggml_tensor * src1, struct ggml_tensor * dst, const float value) { - GGML_ASSERT(src1->type == GGML_TYPE_I32); - GGML_ASSERT(ggml_nelements(src1) == 2); const int ith = params->ith; const int nth = params->nth; - const int n_past = ((int32_t *) src1->data)[0]; - const bool inplace = (bool)((int32_t *) src1->data)[1]; + const int n_past = ((int32_t *) dst->op_params)[0]; + const bool inplace = (bool)((int32_t *) dst->op_params)[1]; GGML_ASSERT(n_past >= 0); @@ -11645,12 +11344,11 @@ static void ggml_compute_forward_diag_mask_f32( static void ggml_compute_forward_diag_mask_inf( const struct ggml_compute_params * params, const struct ggml_tensor * src0, - const struct ggml_tensor * src1, struct ggml_tensor * dst) { switch (src0->type) { case GGML_TYPE_F32: { - ggml_compute_forward_diag_mask_f32(params, src0, src1, dst, -INFINITY); + ggml_compute_forward_diag_mask_f32(params, src0, dst, -INFINITY); } break; default: { @@ -11662,12 +11360,11 @@ static void ggml_compute_forward_diag_mask_inf( static void ggml_compute_forward_diag_mask_zero( const struct ggml_compute_params * params, const struct ggml_tensor * src0, - const struct ggml_tensor * src1, struct ggml_tensor * dst) { switch (src0->type) { case GGML_TYPE_F32: { - ggml_compute_forward_diag_mask_f32(params, src0, src1, dst, 0); + ggml_compute_forward_diag_mask_f32(params, src0, dst, 0); } break; default: { @@ -11865,20 +11562,17 @@ static void ggml_compute_forward_soft_max_back( static void ggml_compute_forward_alibi_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, - const struct ggml_tensor * src1, struct ggml_tensor * dst) { assert(params->ith == 0); - GGML_ASSERT(src1->type == GGML_TYPE_I32); - GGML_ASSERT(ggml_nelements(src1) == 3); - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { return; } - const int n_past = ((int32_t *) src1->data)[0]; - const int n_head = ((int32_t *) src1->data)[1]; - const float max_bias = ((float *) src1->data)[2]; + const int n_past = ((int32_t *) dst->op_params)[0]; + const int n_head = ((int32_t *) dst->op_params)[1]; + float max_bias; + memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float)); assert(n_past >= 0); @@ -11931,20 +11625,17 @@ static void ggml_compute_forward_alibi_f32( static void ggml_compute_forward_alibi_f16( const struct ggml_compute_params * params, const struct ggml_tensor * src0, - const struct ggml_tensor * src1, struct ggml_tensor * dst) { assert(params->ith == 0); - GGML_ASSERT(src1->type == GGML_TYPE_I32); - GGML_ASSERT(ggml_nelements(src1) == 3); - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { return; } - const int n_past = ((int32_t *) src1->data)[0]; - const int n_head = ((int32_t *) src1->data)[1]; - const float max_bias = ((float *) src1->data)[2]; + const int n_past = ((int32_t *) dst->op_params)[0]; + const int n_head = ((int32_t *) dst->op_params)[1]; + float max_bias; + memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float)); assert(n_past >= 0); @@ -11997,16 +11688,15 @@ static void ggml_compute_forward_alibi_f16( static void ggml_compute_forward_alibi( const struct ggml_compute_params * params, const struct ggml_tensor * src0, - const struct ggml_tensor * src1, struct ggml_tensor * dst) { switch (src0->type) { case GGML_TYPE_F16: { - ggml_compute_forward_alibi_f16(params, src0, src1, dst); + ggml_compute_forward_alibi_f16(params, src0, dst); } break; case GGML_TYPE_F32: { - ggml_compute_forward_alibi_f32(params, src0, src1, dst); + ggml_compute_forward_alibi_f32(params, src0, dst); } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: @@ -12036,19 +11726,17 @@ static void ggml_compute_forward_alibi( static void ggml_compute_forward_clamp_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, - const struct ggml_tensor * src1, struct ggml_tensor * dst) { assert(params->ith == 0); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT(ggml_nelements(src1) == 2); - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { return; } - const float min = ((float *) src1->data)[0]; - const float max = ((float *) src1->data)[1]; + float min; + float max; + memcpy(&min, (float *) dst->op_params + 0, sizeof(float)); + memcpy(&max, (float *) dst->op_params + 1, sizeof(float)); const int ith = params->ith; const int nth = params->nth; @@ -12078,12 +11766,11 @@ static void ggml_compute_forward_clamp_f32( static void ggml_compute_forward_clamp( const struct ggml_compute_params * params, const struct ggml_tensor * src0, - const struct ggml_tensor * src1, struct ggml_tensor * dst) { switch (src0->type) { case GGML_TYPE_F32: { - ggml_compute_forward_clamp_f32(params, src0, src1, dst); + ggml_compute_forward_clamp_f32(params, src0, dst); } break; case GGML_TYPE_F16: case GGML_TYPE_Q4_0: @@ -12113,10 +11800,7 @@ static void ggml_compute_forward_clamp( static void ggml_compute_forward_rope_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, - const struct ggml_tensor * src1, struct ggml_tensor * dst) { - GGML_ASSERT(src1->type == GGML_TYPE_I32); - GGML_ASSERT(ggml_nelements(src1) == 6); if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { return; @@ -12125,12 +11809,12 @@ static void ggml_compute_forward_rope_f32( float freq_base; float freq_scale; - const int n_past = ((int32_t *) src1->data)[0]; - const int n_dims = ((int32_t *) src1->data)[1]; - const int mode = ((int32_t *) src1->data)[2]; - const int n_ctx = ((int32_t *) src1->data)[3]; - memcpy(&freq_base, (int32_t *) src1->data + 4, sizeof(float)); - memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float)); + const int n_past = ((int32_t *) dst->op_params)[0]; + const int n_dims = ((int32_t *) dst->op_params)[1]; + const int mode = ((int32_t *) dst->op_params)[2]; + const int n_ctx = ((int32_t *) dst->op_params)[3]; + memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float)); + memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float)); assert(n_past >= 0); @@ -12245,10 +11929,7 @@ static void ggml_compute_forward_rope_f32( static void ggml_compute_forward_rope_f16( const struct ggml_compute_params * params, const struct ggml_tensor * src0, - const struct ggml_tensor * src1, struct ggml_tensor * dst) { - GGML_ASSERT(src1->type == GGML_TYPE_I32); - GGML_ASSERT(ggml_nelements(src1) == 6); if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { return; @@ -12257,12 +11938,12 @@ static void ggml_compute_forward_rope_f16( float freq_base; float freq_scale; - const int n_past = ((int32_t *) src1->data)[0]; - const int n_dims = ((int32_t *) src1->data)[1]; - const int mode = ((int32_t *) src1->data)[2]; - const int n_ctx = ((int32_t *) src1->data)[3]; - memcpy(&freq_base, (int32_t *) src1->data + 4, sizeof(float)); - memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float)); + const int n_past = ((int32_t *) dst->op_params)[0]; + const int n_dims = ((int32_t *) dst->op_params)[1]; + const int mode = ((int32_t *) dst->op_params)[2]; + const int n_ctx = ((int32_t *) dst->op_params)[3]; + memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float)); + memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float)); assert(n_past >= 0); @@ -12377,16 +12058,15 @@ static void ggml_compute_forward_rope_f16( static void ggml_compute_forward_rope( const struct ggml_compute_params * params, const struct ggml_tensor * src0, - const struct ggml_tensor * src1, struct ggml_tensor * dst) { switch (src0->type) { case GGML_TYPE_F16: { - ggml_compute_forward_rope_f16(params, src0, src1, dst); + ggml_compute_forward_rope_f16(params, src0, dst); } break; case GGML_TYPE_F32: { - ggml_compute_forward_rope_f32(params, src0, src1, dst); + ggml_compute_forward_rope_f32(params, src0, dst); } break; default: { @@ -12400,10 +12080,7 @@ static void ggml_compute_forward_rope( static void ggml_compute_forward_rope_back_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, - const struct ggml_tensor * src1, struct ggml_tensor * dst) { - assert(src1->type == GGML_TYPE_I32); - assert(ggml_nelements(src1) == 3); if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { return; @@ -12413,9 +12090,9 @@ static void ggml_compute_forward_rope_back_f32( // dx = rope_back(dy, src1) // src0 is dy, src1 contains options - const int n_past = ((int32_t *) src1->data)[0]; - const int n_dims = ((int32_t *) src1->data)[1]; - const int mode = ((int32_t *) src1->data)[2]; + const int n_past = ((int32_t *) dst->op_params)[0]; + const int n_dims = ((int32_t *) dst->op_params)[1]; + const int mode = ((int32_t *) dst->op_params)[2]; assert(n_past >= 0); @@ -12499,10 +12176,7 @@ static void ggml_compute_forward_rope_back_f32( static void ggml_compute_forward_rope_back_f16( const struct ggml_compute_params * params, const struct ggml_tensor * src0, - const struct ggml_tensor * src1, struct ggml_tensor * dst) { - assert(src1->type == GGML_TYPE_I32); - assert(ggml_nelements(src1) == 3); if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { return; @@ -12512,9 +12186,9 @@ static void ggml_compute_forward_rope_back_f16( // dx = rope_back(dy, src1) // src0 is dy, src1 contains options - const int n_past = ((int32_t *) src1->data)[0]; - const int n_dims = ((int32_t *) src1->data)[1]; - const int mode = ((int32_t *) src1->data)[2]; + const int n_past = ((int32_t *) dst->op_params)[0]; + const int n_dims = ((int32_t *) dst->op_params)[1]; + const int mode = ((int32_t *) dst->op_params)[2]; assert(n_past >= 0); @@ -12598,16 +12272,15 @@ static void ggml_compute_forward_rope_back_f16( static void ggml_compute_forward_rope_back( const struct ggml_compute_params * params, const struct ggml_tensor * src0, - const struct ggml_tensor * src1, struct ggml_tensor * dst) { switch (src0->type) { case GGML_TYPE_F16: { - ggml_compute_forward_rope_back_f16(params, src0, src1, dst); + ggml_compute_forward_rope_back_f16(params, src0, dst); } break; case GGML_TYPE_F32: { - ggml_compute_forward_rope_back_f32(params, src0, src1, dst); + ggml_compute_forward_rope_back_f32(params, src0, dst); } break; default: { @@ -12804,7 +12477,7 @@ static void ggml_compute_forward_conv_1d_s1_ph( const struct ggml_compute_params * params, const struct ggml_tensor * src0, const struct ggml_tensor * src1, - struct ggml_tensor * dst) { + struct ggml_tensor * dst) { switch (src0->type) { case GGML_TYPE_F16: { @@ -13007,7 +12680,7 @@ static void ggml_compute_forward_conv_1d_s2_ph( const struct ggml_compute_params * params, const struct ggml_tensor * src0, const struct ggml_tensor * src1, - struct ggml_tensor * dst) { + struct ggml_tensor * dst) { switch (src0->type) { case GGML_TYPE_F16: { @@ -13027,14 +12700,13 @@ static void ggml_compute_forward_conv_1d_s2_ph( // ggml_compute_forward_conv_1d static void ggml_compute_forward_conv_1d( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - const struct ggml_tensor * opt0, - struct ggml_tensor * dst) { - const int32_t s0 = ((const int32_t*)(opt0->data))[0]; - const int32_t p0 = ((const int32_t*)(opt0->data))[1]; - const int32_t d0 = ((const int32_t*)(opt0->data))[2]; + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; + const int32_t p0 = ((const int32_t*)(dst->op_params))[1]; + const int32_t d0 = ((const int32_t*)(dst->op_params))[2]; GGML_ASSERT(d0 == 1); // dilation not supported GGML_ASSERT(p0 == src0->ne[0]/2); // only half padding supported if (s0 == 1) { @@ -13052,7 +12724,6 @@ static void ggml_compute_forward_conv_2d_f16_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, const struct ggml_tensor * src1, - const struct ggml_tensor * opt0, struct ggml_tensor * dst) { GGML_ASSERT(src0->type == GGML_TYPE_F16); GGML_ASSERT(src1->type == GGML_TYPE_F32); @@ -13072,12 +12743,12 @@ static void ggml_compute_forward_conv_2d_f16_f32( // size of the convolution row - the kernel size unrolled across all channels const int ew0 = nk0*nk1*ne02; - const int32_t s0 = ((const int32_t*)(opt0->data))[0]; - const int32_t s1 = ((const int32_t*)(opt0->data))[1]; - const int32_t p0 = ((const int32_t*)(opt0->data))[2]; - const int32_t p1 = ((const int32_t*)(opt0->data))[3]; - const int32_t d0 = ((const int32_t*)(opt0->data))[4]; - const int32_t d1 = ((const int32_t*)(opt0->data))[5]; + const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; + const int32_t s1 = ((const int32_t*)(dst->op_params))[1]; + const int32_t p0 = ((const int32_t*)(dst->op_params))[2]; + const int32_t p1 = ((const int32_t*)(dst->op_params))[3]; + const int32_t d0 = ((const int32_t*)(dst->op_params))[4]; + const int32_t d1 = ((const int32_t*)(dst->op_params))[5]; GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); GGML_ASSERT(nb10 == sizeof(float)); @@ -13149,17 +12820,15 @@ static void ggml_compute_forward_conv_2d( const struct ggml_compute_params * params, const struct ggml_tensor * src0, const struct ggml_tensor * src1, - const struct ggml_tensor * opt0, - struct ggml_tensor * dst - ) { + struct ggml_tensor * dst) { switch (src0->type) { case GGML_TYPE_F16: { - ggml_compute_forward_conv_2d_f16_f32(params, src0, src1, opt0, dst); + ggml_compute_forward_conv_2d_f16_f32(params, src0, src1, dst); } break; case GGML_TYPE_F32: { - //ggml_compute_forward_conv_2d_f32(params, src0, src1, opt0, dst); + //ggml_compute_forward_conv_2d_f32(params, src0, src1, dst); GGML_ASSERT(false); } break; default: @@ -13224,12 +12893,11 @@ static void ggml_compute_forward_pool_1d_sk_p0( // ggml_compute_forward_pool_1d static void ggml_compute_forward_pool_1d( - const struct ggml_compute_params* params, - const struct ggml_tensor* src0, - const struct ggml_tensor* opt0, - struct ggml_tensor* dst) { - GGML_ASSERT(opt0->ne[0] == 4); - const int* opts = (const int*)opt0->data; + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + + const int32_t* opts = (const int32_t*)dst->op_params; enum ggml_op_pool op = opts[0]; const int k0 = opts[1]; const int s0 = opts[2]; @@ -13243,12 +12911,12 @@ static void ggml_compute_forward_pool_1d( // ggml_compute_forward_pool_2d_sk_p0 static void ggml_compute_forward_pool_2d_sk_p0( - const struct ggml_compute_params * params, - const enum ggml_op_pool op, - const struct ggml_tensor * src, - const int k0, - const int k1, - struct ggml_tensor * dst) { + const struct ggml_compute_params * params, + const enum ggml_op_pool op, + const struct ggml_tensor * src, + const int k0, + const int k1, + struct ggml_tensor * dst) { assert(src->type == GGML_TYPE_F32); assert(params->ith == 0); @@ -13308,12 +12976,11 @@ static void ggml_compute_forward_pool_2d_sk_p0( // ggml_compute_forward_pool_2d static void ggml_compute_forward_pool_2d( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * opt0, - struct ggml_tensor * dst) { - GGML_ASSERT(opt0->ne[0] == 7); - const int* opts = (const int*)opt0->data; + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + + const int32_t * opts = (const int32_t *)dst->op_params; enum ggml_op_pool op = opts[0]; const int k0 = opts[1]; const int k1 = opts[2]; @@ -13338,7 +13005,7 @@ static void ggml_compute_forward_flash_attn_f32( const struct ggml_tensor * k, const struct ggml_tensor * v, const bool masked, - struct ggml_tensor * dst) { + struct ggml_tensor * dst) { int64_t t0 = ggml_perf_time_us(); UNUSED(t0); @@ -13516,7 +13183,7 @@ static void ggml_compute_forward_flash_attn_f16( const struct ggml_tensor * k, const struct ggml_tensor * v, const bool masked, - struct ggml_tensor * dst) { + struct ggml_tensor * dst) { int64_t t0 = ggml_perf_time_us(); UNUSED(t0); @@ -14281,7 +13948,6 @@ static void ggml_compute_forward_flash_attn_back( static void ggml_compute_forward_win_part_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, - const struct ggml_tensor * opt0, struct ggml_tensor * dst) { if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { return; @@ -14290,9 +13956,9 @@ static void ggml_compute_forward_win_part_f32( GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne); GGML_TENSOR_LOCALS(int64_t, ne, dst, ne); - const int32_t nep0 = ((const int32_t *)(opt0->data))[0]; - const int32_t nep1 = ((const int32_t *)(opt0->data))[1]; - const int32_t w = ((const int32_t *)(opt0->data))[2]; + const int32_t nep0 = ((const int32_t *)(dst->op_params))[0]; + const int32_t nep1 = ((const int32_t *)(dst->op_params))[1]; + const int32_t w = ((const int32_t *)(dst->op_params))[2]; assert(ne00 == ne0); assert(ne3 == nep0*nep1); @@ -14326,12 +13992,11 @@ static void ggml_compute_forward_win_part_f32( static void ggml_compute_forward_win_part( const struct ggml_compute_params * params, const struct ggml_tensor * src0, - const struct ggml_tensor * opt0, struct ggml_tensor * dst) { switch (src0->type) { case GGML_TYPE_F32: { - ggml_compute_forward_win_part_f32(params, src0, opt0, dst); + ggml_compute_forward_win_part_f32(params, src0, dst); } break; default: { @@ -14345,7 +14010,6 @@ static void ggml_compute_forward_win_part( static void ggml_compute_forward_win_unpart_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, - const struct ggml_tensor * opt0, struct ggml_tensor * dst) { if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { return; @@ -14354,7 +14018,7 @@ static void ggml_compute_forward_win_unpart_f32( GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne); GGML_TENSOR_LOCALS(int64_t, ne, dst, ne); - const int32_t w = ((const int32_t *)(opt0->data))[0]; + const int32_t w = ((const int32_t *)(dst->op_params))[0]; // padding const int px = (w - ne1%w)%w; @@ -14388,12 +14052,67 @@ static void ggml_compute_forward_win_unpart_f32( static void ggml_compute_forward_win_unpart( const struct ggml_compute_params * params, const struct ggml_tensor * src0, - const struct ggml_tensor * opt0, struct ggml_tensor * dst) { switch (src0->type) { case GGML_TYPE_F32: { - ggml_compute_forward_win_unpart_f32(params, src0, opt0, dst); + ggml_compute_forward_win_unpart_f32(params, src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +//gmml_compute_forward_unary + +static void ggml_compute_forward_unary( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + const enum ggml_unary_op op = ggml_get_unary_op(dst); + + switch (op) { + case GGML_UNARY_OP_ABS: + { + ggml_compute_forward_abs(params, src0, dst); + } break; + case GGML_UNARY_OP_SGN: + { + ggml_compute_forward_sgn(params, src0, dst); + } break; + case GGML_UNARY_OP_NEG: + { + ggml_compute_forward_neg(params, src0, dst); + } break; + case GGML_UNARY_OP_STEP: + { + ggml_compute_forward_step(params, src0, dst); + } break; + case GGML_UNARY_OP_TANH: + { + ggml_compute_forward_tanh(params, src0, dst); + } break; + case GGML_UNARY_OP_ELU: + { + ggml_compute_forward_elu(params, src0, dst); + } break; + case GGML_UNARY_OP_RELU: + { + ggml_compute_forward_relu(params, src0, dst); + } break; + case GGML_UNARY_OP_GELU: + { + ggml_compute_forward_gelu(params, src0, dst); + } break; + case GGML_UNARY_OP_GELU_QUICK: + { + ggml_compute_forward_gelu_quick(params, src0, dst); + } break; + case GGML_UNARY_OP_SILU: + { + ggml_compute_forward_silu(params, src0, dst); } break; default: { @@ -14912,7 +14631,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm } break; case GGML_OP_ACC: { - ggml_compute_forward_acc(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor); + ggml_compute_forward_acc(params, tensor->src[0], tensor->src[1], tensor); } break; case GGML_OP_SUB: { @@ -14962,46 +14681,6 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_repeat_back(params, tensor->src[0], tensor); } break; - case GGML_OP_ABS: - { - ggml_compute_forward_abs(params, tensor->src[0], tensor); - } break; - case GGML_OP_SGN: - { - ggml_compute_forward_sgn(params, tensor->src[0], tensor); - } break; - case GGML_OP_NEG: - { - ggml_compute_forward_neg(params, tensor->src[0], tensor); - } break; - case GGML_OP_STEP: - { - ggml_compute_forward_step(params, tensor->src[0], tensor); - } break; - case GGML_OP_TANH: - { - ggml_compute_forward_tanh(params, tensor->src[0], tensor); - } break; - case GGML_OP_ELU: - { - ggml_compute_forward_elu(params, tensor->src[0], tensor); - } break; - case GGML_OP_RELU: - { - ggml_compute_forward_relu(params, tensor->src[0], tensor); - } break; - case GGML_OP_GELU: - { - ggml_compute_forward_gelu(params, tensor->src[0], tensor); - } break; - case GGML_OP_GELU_QUICK: - { - ggml_compute_forward_gelu_quick(params, tensor->src[0], tensor); - } break; - case GGML_OP_SILU: - { - ggml_compute_forward_silu(params, tensor->src[0], tensor); - } break; case GGML_OP_SILU_BACK: { ggml_compute_forward_silu_back(params, tensor->src[0], tensor->src[1], tensor); @@ -15032,7 +14711,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm } break; case GGML_OP_SET: { - ggml_compute_forward_set(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor); + ggml_compute_forward_set(params, tensor->src[0], tensor->src[1], tensor); } break; case GGML_OP_CPY: { @@ -15072,11 +14751,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm } break; case GGML_OP_DIAG_MASK_INF: { - ggml_compute_forward_diag_mask_inf(params, tensor->src[0], tensor->src[1], tensor); + ggml_compute_forward_diag_mask_inf(params, tensor->src[0], tensor); } break; case GGML_OP_DIAG_MASK_ZERO: { - ggml_compute_forward_diag_mask_zero(params, tensor->src[0], tensor->src[1], tensor); + ggml_compute_forward_diag_mask_zero(params, tensor->src[0], tensor); } break; case GGML_OP_SOFT_MAX: { @@ -15088,39 +14767,39 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm } break; case GGML_OP_ROPE: { - ggml_compute_forward_rope(params, tensor->src[0], tensor->src[1], tensor); + ggml_compute_forward_rope(params, tensor->src[0], tensor); } break; case GGML_OP_ROPE_BACK: { - ggml_compute_forward_rope_back(params, tensor->src[0], tensor->src[1], tensor); + ggml_compute_forward_rope_back(params, tensor->src[0], tensor); } break; case GGML_OP_ALIBI: { - ggml_compute_forward_alibi(params, tensor->src[0], tensor->src[1], tensor); + ggml_compute_forward_alibi(params, tensor->src[0], tensor); } break; case GGML_OP_CLAMP: { - ggml_compute_forward_clamp(params, tensor->src[0], tensor->src[1], tensor); + ggml_compute_forward_clamp(params, tensor->src[0], tensor); } break; case GGML_OP_CONV_1D: { - ggml_compute_forward_conv_1d(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor); + ggml_compute_forward_conv_1d(params, tensor->src[0], tensor->src[1], tensor); } break; case GGML_OP_CONV_2D: { - ggml_compute_forward_conv_2d(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor); + ggml_compute_forward_conv_2d(params, tensor->src[0], tensor->src[1], tensor); } break; case GGML_OP_POOL_1D: { - ggml_compute_forward_pool_1d(params, tensor->src[0], tensor->src[1], tensor); + ggml_compute_forward_pool_1d(params, tensor->src[0], tensor); } break; case GGML_OP_POOL_2D: { - ggml_compute_forward_pool_2d(params, tensor->src[0], tensor->src[1], tensor); + ggml_compute_forward_pool_2d(params, tensor->src[0], tensor); } break; case GGML_OP_FLASH_ATTN: { - const int32_t t = ggml_get_i32_1d(tensor->src[3], 0); + const int32_t t = ggml_get_op_params_i32(tensor, 0); GGML_ASSERT(t == 0 || t == 1); const bool masked = t != 0; ggml_compute_forward_flash_attn(params, tensor->src[0], tensor->src[1], tensor->src[2], masked, tensor); @@ -15131,47 +14810,56 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm } break; case GGML_OP_FLASH_ATTN_BACK: { - int32_t t = ggml_get_i32_1d(tensor->src[4], 0); + int32_t t = ggml_get_op_params_i32(tensor, 0); GGML_ASSERT(t == 0 || t == 1); bool masked = t != 0; ggml_compute_forward_flash_attn_back(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], masked, tensor); } break; case GGML_OP_WIN_PART: { - ggml_compute_forward_win_part(params, tensor->src[0], tensor->src[2], tensor); + ggml_compute_forward_win_part(params, tensor->src[0], tensor); } break; case GGML_OP_WIN_UNPART: { - ggml_compute_forward_win_unpart(params, tensor->src[0], tensor->src[2], tensor); + ggml_compute_forward_win_unpart(params, tensor->src[0], tensor); + } break; + case GGML_OP_UNARY: + { + ggml_compute_forward_unary(params, tensor->src[0], tensor); } break; case GGML_OP_MAP_UNARY: { - const ggml_unary_op_f32_t fun = *((ggml_unary_op_f32_t *)tensor->src[2]->data); + ggml_unary_op_f32_t fun; + memcpy(&fun, tensor->op_params, sizeof(fun)); ggml_compute_forward_map_unary(params, tensor->src[0], tensor, fun); } break; case GGML_OP_MAP_BINARY: { - const ggml_binary_op_f32_t fun = *((ggml_binary_op_f32_t *)tensor->src[2]->data); + ggml_binary_op_f32_t fun; + memcpy(&fun, tensor->op_params, sizeof(fun)); ggml_compute_forward_map_binary(params, tensor->src[0], tensor->src[1], tensor, fun); } break; case GGML_OP_MAP_CUSTOM1: { - const ggml_custom1_op_f32_t fun = *((ggml_custom1_op_f32_t *)tensor->src[2]->data); + ggml_custom1_op_f32_t fun; + memcpy(&fun, tensor->op_params, sizeof(fun)); ggml_compute_forward_map_custom1(params, tensor->src[0], tensor, fun); } break; case GGML_OP_MAP_CUSTOM2: { - const ggml_custom2_op_f32_t fun = *((ggml_custom2_op_f32_t *)tensor->src[2]->data); + ggml_custom2_op_f32_t fun; + memcpy(&fun, tensor->op_params, sizeof(fun)); ggml_compute_forward_map_custom2(params, tensor->src[0], tensor->src[1], tensor, fun); } break; case GGML_OP_MAP_CUSTOM3: { - const ggml_custom3_op_f32_t fun = *((ggml_custom3_op_f32_t *)tensor->src[2]->data); - ggml_compute_forward_map_custom3(params, tensor->src[0], tensor->src[1], tensor->src[3], tensor, fun); + ggml_custom3_op_f32_t fun; + memcpy(&fun, tensor->op_params, sizeof(fun)); + ggml_compute_forward_map_custom3(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor, fun); } break; case GGML_OP_CROSS_ENTROPY_LOSS: @@ -15235,12 +14923,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace); } if (src1->grad) { - GGML_ASSERT(ggml_nelements(tensor->src[2]) == 5); - GGML_ASSERT(tensor->src[2]->type == GGML_TYPE_I32); - const size_t nb1 = (( int32_t * ) tensor->src[2]->data)[0]; - const size_t nb2 = (( int32_t * ) tensor->src[2]->data)[1]; - const size_t nb3 = (( int32_t * ) tensor->src[2]->data)[2]; - const size_t offset = (( int32_t * ) tensor->src[2]->data)[3]; + const size_t nb1 = ((int32_t *) tensor->op_params)[0]; + const size_t nb2 = ((int32_t *) tensor->op_params)[1]; + const size_t nb3 = ((int32_t *) tensor->op_params)[2]; + const size_t offset = ((int32_t *) tensor->op_params)[3]; struct ggml_tensor * tensor_grad_view = ggml_view_4d(ctx, tensor->grad, @@ -15389,73 +15075,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor inplace); } } break; - case GGML_OP_ABS: - { - if (src0->grad) { - src0->grad = - ggml_add_impl(ctx, - src0->grad, - ggml_mul(ctx, - ggml_sgn(ctx, src0), - tensor->grad), - inplace); - } - } break; - case GGML_OP_SGN: - { - if (src0->grad) { - // noop - } - } break; - case GGML_OP_NEG: - { - if (src0->grad) { - src0->grad = ggml_sub_impl(ctx, src0->grad, tensor->grad, inplace); - } - } break; - case GGML_OP_STEP: - { - if (src0->grad) { - // noop - } - } break; - case GGML_OP_TANH: - { - GGML_ASSERT(false); // TODO: not implemented - } break; - case GGML_OP_ELU: - { - GGML_ASSERT(false); // TODO: not implemented - } break; - case GGML_OP_RELU: - { - if (src0->grad) { - src0->grad = ggml_sub_impl(ctx, - src0->grad, - ggml_mul(ctx, - ggml_step(ctx, src0), - tensor->grad), - inplace); - } - } break; - case GGML_OP_GELU: - { - GGML_ASSERT(false); // TODO: not implemented - } break; - case GGML_OP_GELU_QUICK: - { - GGML_ASSERT(false); // TODO: not implemented - } break; - case GGML_OP_SILU: - { - // necessary for llama - if (src0->grad) { - src0->grad = ggml_add_impl(ctx, - src0->grad, - ggml_silu_back(ctx, src0, tensor->grad), - inplace); - } - } break; case GGML_OP_SILU_BACK: { GGML_ASSERT(false); // TODO: not implemented @@ -15548,12 +15167,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor } break; case GGML_OP_SET: { - GGML_ASSERT(ggml_nelements(tensor->src[2]) == 5); - GGML_ASSERT(tensor->src[2]->type == GGML_TYPE_I32); - const size_t nb1 = (( int32_t * ) tensor->src[2]->data)[0]; - const size_t nb2 = (( int32_t * ) tensor->src[2]->data)[1]; - const size_t nb3 = (( int32_t * ) tensor->src[2]->data)[2]; - const size_t offset = (( int32_t * ) tensor->src[2]->data)[3]; + const size_t nb1 = ((int32_t *) tensor->op_params)[0]; + const size_t nb2 = ((int32_t *) tensor->op_params)[1]; + const size_t nb3 = ((int32_t *) tensor->op_params)[2]; + const size_t offset = ((int32_t *) tensor->op_params)[3]; struct ggml_tensor * tensor_grad_view = NULL; @@ -15630,8 +15247,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor if (src0->grad) { size_t offset; - GGML_ASSERT(sizeof(offset) <= ggml_nbytes(tensor->src[2])); - memcpy(&offset, tensor->src[2]->data, sizeof(offset)); + memcpy(&offset, tensor->op_params, sizeof(offset)); size_t nb1 = tensor->nb[1]; size_t nb2 = tensor->nb[2]; @@ -15658,7 +15274,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor { // necessary for llama if (src0->grad) { - int32_t * axes = (int32_t *) tensor->src[2]->data; + int32_t * axes = (int32_t *) tensor->op_params; int axis0 = axes[0] & 0x3; int axis1 = axes[1] & 0x3; int axis2 = axes[2] & 0x3; @@ -15714,33 +15330,23 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor { // necessary for llama if (src0->grad) { - assert(src1->type == GGML_TYPE_I32); - assert(ggml_nelements(src1) == 2); - const int n_past = ((int32_t *) src1->data)[0]; + const int n_past = ((int32_t *) tensor->op_params)[0]; src0->grad = ggml_add_impl(ctx, src0->grad, ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false), inplace); } - if (src1->grad) { - // noop - } } break; case GGML_OP_DIAG_MASK_ZERO: { // necessary for llama if (src0->grad) { - assert(src1->type == GGML_TYPE_I32); - assert(ggml_nelements(src1) == 2); - const int n_past = ((int32_t *) src1->data)[0]; + const int n_past = ((int32_t *) tensor->op_params)[0]; src0->grad = ggml_add_impl(ctx, src0->grad, ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false), inplace); } - if (src1->grad) { - // noop - } } break; case GGML_OP_SOFT_MAX: { @@ -15761,33 +15367,28 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor { // necessary for llama if (src0->grad) { - assert(src1->type == GGML_TYPE_I32); - assert(ggml_nelements(src1) == 6); - const int n_past = ((int32_t *) src1->data)[0]; - const int n_dims = ((int32_t *) src1->data)[1]; - const int mode = ((int32_t *) src1->data)[2]; + const int n_past = ((int32_t *) tensor->op_params)[0]; + const int n_dims = ((int32_t *) tensor->op_params)[1]; + const int mode = ((int32_t *) tensor->op_params)[2]; + const int n_ctx = ((int32_t *) tensor->op_params)[3]; src0->grad = ggml_add_impl(ctx, src0->grad, ggml_rope_back(ctx, tensor->grad, n_past, n_dims, - mode), + mode, + n_ctx), inplace); } - if (src1->grad) { - // noop - } } break; case GGML_OP_ROPE_BACK: { if (src0->grad) { - assert(src1->type == GGML_TYPE_I32); - assert(ggml_nelements(src1) == 3); - const int n_past = ((int32_t *) src1->data)[0]; - const int n_dims = ((int32_t *) src1->data)[1]; - const int mode = ((int32_t *) src1->data)[2]; - const int n_ctx = ((int32_t *) src1->data)[3]; + const int n_past = ((int32_t *) tensor->op_params)[0]; + const int n_dims = ((int32_t *) tensor->op_params)[1]; + const int mode = ((int32_t *) tensor->op_params)[2]; + const int n_ctx = ((int32_t *) tensor->op_params)[3]; src0->grad = ggml_add_impl(ctx, src0->grad, ggml_rope(ctx, @@ -15798,9 +15399,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor n_ctx), inplace); } - if (src1->grad) { - // noop - } } break; case GGML_OP_ALIBI: { @@ -15830,7 +15428,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor { struct ggml_tensor * flash_grad = NULL; if (src0->grad || src1->grad || tensor->src[2]->grad) { - int32_t t = ggml_get_i32_1d(tensor->src[3], 0); + int32_t t = ggml_get_op_params_i32(tensor, 0); GGML_ASSERT(t == 0 || t == 1); bool masked = t != 0; flash_grad = @@ -15993,6 +15591,80 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor } break; case GGML_OP_WIN_PART: case GGML_OP_WIN_UNPART: + case GGML_OP_UNARY: + { + switch (ggml_get_unary_op(tensor)) { + case GGML_UNARY_OP_ABS: + { + if (src0->grad) { + src0->grad = + ggml_add_impl(ctx, + src0->grad, + ggml_mul(ctx, + ggml_sgn(ctx, src0), + tensor->grad), + inplace); + } + } break; + case GGML_UNARY_OP_SGN: + { + if (src0->grad) { + // noop + } + } break; + case GGML_UNARY_OP_NEG: + { + if (src0->grad) { + src0->grad = ggml_sub_impl(ctx, src0->grad, tensor->grad, inplace); + } + } break; + case GGML_UNARY_OP_STEP: + { + if (src0->grad) { + // noop + } + } break; + case GGML_UNARY_OP_TANH: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_UNARY_OP_ELU: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_UNARY_OP_RELU: + { + if (src0->grad) { + src0->grad = ggml_add_impl(ctx, + src0->grad, + ggml_mul(ctx, + ggml_step(ctx, src0), + tensor->grad), + inplace); + } + } break; + case GGML_UNARY_OP_GELU: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_UNARY_OP_GELU_QUICK: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_UNARY_OP_SILU: + { + // necessary for llama + if (src0->grad) { + src0->grad = ggml_add_impl(ctx, + src0->grad, + ggml_silu_back(ctx, src0, tensor->grad), + inplace); + } + } break; + default: + GGML_ASSERT(false); + } + } break; case GGML_OP_MAP_UNARY: case GGML_OP_MAP_BINARY: case GGML_OP_MAP_CUSTOM1: @@ -16028,6 +15700,34 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor } } +static_assert(GGML_GRAPH_HASHTABLE_SIZE > GGML_MAX_NODES * 2, "GGML_GRAPH_HT_SIZE is too small"); + +static size_t hash(void * p) { + return (size_t)p % GGML_GRAPH_HASHTABLE_SIZE; +} + +static bool hash_insert(void * hash_table[], void * p) { + size_t h = hash(p); + + // linear probing + size_t i = h; + while (hash_table[i] != NULL && hash_table[i] != p) { + i = (i + 1) % GGML_GRAPH_HASHTABLE_SIZE; + if (i == h) { + // hash table is full + GGML_ASSERT(false); + } + } + + if (hash_table[i] == p) { + return true; + } + + // insert + hash_table[i] = p; + return false; +} + static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) { if (node->grad == NULL) { // this usually happens when we generate intermediate nodes from constants in the backward pass @@ -16038,16 +15738,8 @@ static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * } // check if already visited - for (int i = 0; i < cgraph->n_nodes; i++) { - if (cgraph->nodes[i] == node) { - return; - } - } - - for (int i = 0; i < cgraph->n_leafs; i++) { - if (cgraph->leafs[i] == node) { - return; - } + if (hash_insert(cgraph->visited_hash_table, node)) { + return; } for (int i = 0; i < GGML_MAX_SRC; ++i) { @@ -16110,6 +15802,7 @@ struct ggml_cgraph ggml_build_forward(struct ggml_tensor * tensor) { /*.nodes =*/ { NULL }, /*.grads =*/ { NULL }, /*.leafs =*/ { NULL }, + /*.hash_table =*/ { NULL }, /*.perf_runs =*/ 0, /*.perf_cycles =*/ 0, /*.perf_time_us =*/ 0, @@ -16151,13 +15844,42 @@ struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cg if (node->is_param) { GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node); - ggml_build_forward_impl(&result, node->grad, true); + ggml_build_forward_expand(&result, node->grad); } } return result; } +struct ggml_cgraph * ggml_new_graph(struct ggml_context * ctx) { + struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_GRAPH, GGML_GRAPH_SIZE); + struct ggml_cgraph * cgraph = (struct ggml_cgraph *) ((char *) ctx->mem_buffer + obj->offs); + + *cgraph = (struct ggml_cgraph) { + /*.n_nodes =*/ 0, + /*.n_leafs =*/ 0, + /*.nodes =*/ { NULL }, + /*.grads =*/ { NULL }, + /*.leafs =*/ { NULL }, + /*.hash_table =*/ { NULL }, + /*.perf_runs =*/ 0, + /*.perf_cycles =*/ 0, + /*.perf_time_us =*/ 0, + }; + + return cgraph; +} + +struct ggml_cgraph * ggml_build_forward_ctx(struct ggml_context * ctx, struct ggml_tensor * tensor) { + struct ggml_cgraph * cgraph = ggml_new_graph(ctx); + ggml_build_forward_impl(cgraph, tensor, false); + return cgraph; +} + +size_t ggml_graph_overhead(void) { + return GGML_OBJECT_SIZE + GGML_PAD(GGML_GRAPH_SIZE, GGML_MEM_ALIGN); +} + // // thread data // @@ -16223,7 +15945,7 @@ typedef pthread_t ggml_thread_t; // Android's libc implementation "bionic" does not support setting affinity #if defined(__linux__) && !defined(__BIONIC__) -void set_numa_thread_affinity(int thread_n, int n_threads) { +static void set_numa_thread_affinity(int thread_n, int n_threads) { if (!ggml_is_numa()) { return; } @@ -16248,7 +15970,7 @@ void set_numa_thread_affinity(int thread_n, int n_threads) { CPU_FREE(cpus); } -void clear_numa_thread_affinity(void) { +static void clear_numa_thread_affinity(void) { if (!ggml_is_numa()) { return; } @@ -16272,8 +15994,8 @@ void clear_numa_thread_affinity(void) { #else // TODO: Windows etc. // (the linux implementation may also work on BSD, someone should test) -void set_numa_thread_affinity(int thread_n, int n_threads) { UNUSED(thread_n); UNUSED(n_threads); } -void clear_numa_thread_affinity(void) {} +static void set_numa_thread_affinity(int thread_n, int n_threads) { UNUSED(thread_n); UNUSED(n_threads); } +static void clear_numa_thread_affinity(void) {} #endif struct ggml_compute_state_shared { @@ -16485,21 +16207,34 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) { case GGML_OP_ARGMAX: case GGML_OP_REPEAT: case GGML_OP_REPEAT_BACK: - case GGML_OP_ABS: - case GGML_OP_SGN: - case GGML_OP_NEG: - case GGML_OP_STEP: - case GGML_OP_TANH: - case GGML_OP_ELU: - case GGML_OP_RELU: - { + { n_tasks = 1; } break; - case GGML_OP_MUL: - case GGML_OP_GELU: - case GGML_OP_GELU_QUICK: - case GGML_OP_SILU: + + case GGML_OP_UNARY: + { + switch (ggml_get_unary_op(node)) { + case GGML_UNARY_OP_ABS: + case GGML_UNARY_OP_SGN: + case GGML_UNARY_OP_NEG: + case GGML_UNARY_OP_STEP: + case GGML_UNARY_OP_TANH: + case GGML_UNARY_OP_ELU: + case GGML_UNARY_OP_RELU: + { + n_tasks = 1; + } break; + + case GGML_UNARY_OP_GELU: + case GGML_UNARY_OP_GELU_QUICK: + case GGML_UNARY_OP_SILU: + { + n_tasks = n_threads; + } break; + } + } break; case GGML_OP_SILU_BACK: + case GGML_OP_MUL: case GGML_OP_NORM: case GGML_OP_RMS_NORM: case GGML_OP_RMS_NORM_BACK: @@ -16564,10 +16299,10 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) { case GGML_OP_GET_ROWS: case GGML_OP_GET_ROWS_BACK: case GGML_OP_DIAG: - case GGML_OP_DIAG_MASK_ZERO: { n_tasks = 1; } break; + case GGML_OP_DIAG_MASK_ZERO: case GGML_OP_DIAG_MASK_INF: case GGML_OP_SOFT_MAX: case GGML_OP_SOFT_MAX_BACK: @@ -16860,10 +16595,9 @@ void ggml_graph_reset(struct ggml_cgraph * cgraph) { void ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads) { struct ggml_cplan cplan = ggml_graph_plan(cgraph, n_threads); - struct ggml_tensor * buf = ggml_new_tensor_1d(ctx, GGML_TYPE_I8, cplan.work_size); - GGML_ASSERT(buf); + struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_WORK_BUFFER, cplan.work_size); - cplan.work_data = buf->data; + cplan.work_data = (uint8_t *)ctx->mem_buffer + obj->offs; ggml_graph_compute(cgraph, &cplan); } @@ -17014,7 +16748,8 @@ void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) { fwrite(&nb, sizeof(uint64_t), 1, fout); } - fwrite(tensor->name, sizeof(char), GGML_MAX_NAME, fout); + fwrite(tensor->name, sizeof(char), GGML_MAX_NAME, fout); + fwrite(tensor->op_params, sizeof(char), GGML_MAX_OP_PARAMS, fout); // dump the data // TODO: pad this to 32 byte boundary @@ -17047,7 +16782,8 @@ void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) { fwrite(&nb, sizeof(uint64_t), 1, fout); } - fwrite(tensor->name, sizeof(char), GGML_MAX_NAME, fout); + fwrite(tensor->name, sizeof(char), GGML_MAX_NAME, fout); + fwrite(tensor->op_params, sizeof(char), GGML_MAX_OP_PARAMS, fout); // output the op arguments { @@ -17228,7 +16964,8 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context ** tensor->op = (enum ggml_op) op; - memcpy(tensor->name, ptr, GGML_MAX_NAME); ptr += GGML_MAX_NAME; + memcpy(tensor->name, ptr, GGML_MAX_NAME); ptr += GGML_MAX_NAME; + memcpy(tensor->op_params, ptr, GGML_MAX_OP_PARAMS); ptr += GGML_MAX_OP_PARAMS; tensor->data = (void *) ptr; @@ -17273,7 +17010,8 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context ** nb[j] = nb_cur; } - const char * ptr_name = ptr; ptr += GGML_MAX_NAME; + const char * ptr_name = ptr; ptr += GGML_MAX_NAME; + const char * ptr_op_params = ptr; ptr += GGML_MAX_OP_PARAMS; const int32_t * ptr_arg_idx = (const int32_t *) ptr; ptr += GGML_MAX_SRC*sizeof(int32_t); @@ -17310,8 +17048,8 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context ** { tensor = ggml_view_4d(*ctx_eval, args[0], ne[0], ne[1], ne[2], ne[3], 0, 0, 0, 0); - uint64_t offs; - memcpy(&offs, args[2]->data, sizeof(offs)); + size_t offs; + memcpy(&offs, ptr_op_params, sizeof(offs)); tensor->data = ((char *) tensor->data) + offs; } break; @@ -17331,7 +17069,8 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context ** } break; } - memcpy(tensor->name, ptr_name, GGML_MAX_NAME); + memcpy(tensor->name, ptr_name, GGML_MAX_NAME); + memcpy(tensor->op_params, ptr_op_params, GGML_MAX_OP_PARAMS); for (int j = 0; j < GGML_MAX_DIMS; ++j) { tensor->nb[j] = nb[j]; @@ -17365,7 +17104,7 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) { GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 ", %5" PRId64 "] %16s %s (%3d) cpu = %7.3f / %7.3f ms, wall = %7.3f / %7.3f ms\n", i, node->ne[0], node->ne[1], node->ne[2], - GGML_OP_NAME[node->op], node->is_param ? "x" : node->grad ? "g" : " ", node->perf_runs, + ggml_op_name(node->op), node->is_param ? "x" : node->grad ? "g" : " ", node->perf_runs, (double) node->perf_cycles / (double) ggml_cycles_per_ms(), (double) node->perf_cycles / (double) ggml_cycles_per_ms() / (double) node->perf_runs, (double) node->perf_time_us / 1000.0, @@ -17379,7 +17118,7 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) { GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 "] %8s\n", i, node->ne[0], node->ne[1], - GGML_OP_NAME[node->op]); + ggml_op_name(node->op)); } for (int i = 0; i < GGML_OP_COUNT; i++) { @@ -17387,7 +17126,7 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) { continue; } - GGML_PRINT("perf_total_per_op_us[%16s] = %7.3f ms\n", GGML_OP_NAME[i], (double) perf_total_per_op_us[i] / 1000.0); + GGML_PRINT("perf_total_per_op_us[%16s] = %7.3f ms\n", ggml_op_name(i), (double) perf_total_per_op_us[i] / 1000.0); } GGML_PRINT("========================================\n"); @@ -17481,13 +17220,13 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph } if (node->n_dims == 2) { - fprintf(fp, "%d [%" PRId64 ", %" PRId64 "] | %s", i, node->ne[0], node->ne[1], GGML_OP_SYMBOL[node->op]); + fprintf(fp, "%d [%" PRId64 ", %" PRId64 "] | %s", i, node->ne[0], node->ne[1], ggml_op_symbol(node->op)); } else { - fprintf(fp, "%d [%" PRId64 ", %" PRId64 ", %" PRId64 "] | %s", i, node->ne[0], node->ne[1], node->ne[2], GGML_OP_SYMBOL[node->op]); + fprintf(fp, "%d [%" PRId64 ", %" PRId64 ", %" PRId64 "] | %s", i, node->ne[0], node->ne[1], node->ne[2], ggml_op_symbol(node->op)); } if (node->grad) { - fprintf(fp, " | %s\"; ]\n", GGML_OP_SYMBOL[node->grad->op]); + fprintf(fp, " | %s\"; ]\n", ggml_op_symbol(node->grad->op)); } else { fprintf(fp, "\"; ]\n"); } diff --git a/llama/ggml.h b/llama/ggml.h index d8d769c8..3069ae58 100644 --- a/llama/ggml.h +++ b/llama/ggml.h @@ -1,5 +1,5 @@ /** - * llama.cpp - git e782c9e735f93ab4767ffc37462c523b73a17ddc + * llama.cpp - git 7c529cede6e84054e77a3eceab31c53de7b2f55b * * MIT License * @@ -225,6 +225,7 @@ #define GGML_MAX_CONTEXTS 64 #define GGML_MAX_SRC 6 #define GGML_MAX_NAME 48 +#define GGML_MAX_OP_PARAMS 32 #define GGML_DEFAULT_N_THREADS 4 @@ -233,6 +234,7 @@ #define GGML_UNUSED(x) (void)(x) +#define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1)) #define GGML_ASSERT(x) \ do { \ @@ -355,16 +357,6 @@ extern "C" { GGML_OP_ARGMAX, GGML_OP_REPEAT, GGML_OP_REPEAT_BACK, - GGML_OP_ABS, - GGML_OP_SGN, - GGML_OP_NEG, - GGML_OP_STEP, - GGML_OP_TANH, - GGML_OP_ELU, - GGML_OP_RELU, - GGML_OP_GELU, - GGML_OP_GELU_QUICK, - GGML_OP_SILU, GGML_OP_SILU_BACK, GGML_OP_NORM, // normalize GGML_OP_RMS_NORM, @@ -403,6 +395,8 @@ extern "C" { GGML_OP_WIN_PART, GGML_OP_WIN_UNPART, + GGML_OP_UNARY, + GGML_OP_MAP_UNARY, GGML_OP_MAP_BINARY, @@ -416,6 +410,24 @@ extern "C" { GGML_OP_COUNT, }; + enum ggml_unary_op { + GGML_UNARY_OP_ABS, + GGML_UNARY_OP_SGN, + GGML_UNARY_OP_NEG, + GGML_UNARY_OP_STEP, + GGML_UNARY_OP_TANH, + GGML_UNARY_OP_ELU, + GGML_UNARY_OP_RELU, + GGML_UNARY_OP_GELU, + GGML_UNARY_OP_GELU_QUICK, + GGML_UNARY_OP_SILU, + }; + + enum ggml_object_type { + GGML_OBJECT_TENSOR, + GGML_OBJECT_GRAPH, + GGML_OBJECT_WORK_BUFFER + }; // ggml object struct ggml_object { @@ -424,7 +436,9 @@ extern "C" { struct ggml_object * next; - char padding[8]; + enum ggml_object_type type; + + char padding[4]; }; static const size_t GGML_OBJECT_SIZE = sizeof(struct ggml_object); @@ -444,6 +458,9 @@ extern "C" { // compute data enum ggml_op op; + // op params - allocated as int32_t for alignment + int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)]; + bool is_param; struct ggml_tensor * grad; @@ -460,7 +477,7 @@ extern "C" { void * extra; // extra things e.g. for ggml-cuda.cu - char padding[8]; + char padding[4]; }; static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor); @@ -481,6 +498,11 @@ extern "C" { void * abort_callback_data; }; + // next prime after GGML_MAX_NODES + // #define GGML_GRAPH_HASHTABLE_SIZE 4099 + // next prime after GGML_MAX_NODES * 2 (nodes + leafs) + #define GGML_GRAPH_HASHTABLE_SIZE 8273 + // computation graph struct ggml_cgraph { int n_nodes; @@ -490,12 +512,16 @@ extern "C" { struct ggml_tensor * grads[GGML_MAX_NODES]; struct ggml_tensor * leafs[GGML_MAX_NODES]; + void * visited_hash_table[GGML_GRAPH_HASHTABLE_SIZE]; + // performance int perf_runs; int64_t perf_cycles; int64_t perf_time_us; }; + static const size_t GGML_GRAPH_SIZE = sizeof(struct ggml_cgraph); + // scratch buffer struct ggml_scratch { size_t offs; @@ -557,6 +583,7 @@ extern "C" { GGML_API const char * ggml_type_name(enum ggml_type type); GGML_API const char * ggml_op_name (enum ggml_op op); + GGML_API const char * ggml_op_symbol(enum ggml_op op); GGML_API size_t ggml_element_size(const struct ggml_tensor * tensor); @@ -580,6 +607,7 @@ extern "C" { GGML_API size_t ggml_used_mem(const struct ggml_context * ctx); GGML_API size_t ggml_set_scratch (struct ggml_context * ctx, struct ggml_scratch scratch); + GGML_API bool ggml_get_no_alloc(struct ggml_context * ctx); GGML_API void ggml_set_no_alloc(struct ggml_context * ctx, bool no_alloc); GGML_API void * ggml_get_mem_buffer (const struct ggml_context * ctx); @@ -639,9 +667,11 @@ extern "C" { GGML_API void * ggml_get_data (const struct ggml_tensor * tensor); GGML_API float * ggml_get_data_f32(const struct ggml_tensor * tensor); - GGML_API const char * ggml_get_name(const struct ggml_tensor * tensor); - GGML_API struct ggml_tensor * ggml_set_name(struct ggml_tensor * tensor, const char * name); - GGML_API struct ggml_tensor * ggml_format_name(struct ggml_tensor * tensor, const char * fmt, ...); + GGML_API enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor); + + GGML_API const char * ggml_get_name (const struct ggml_tensor * tensor); + GGML_API struct ggml_tensor * ggml_set_name ( struct ggml_tensor * tensor, const char * name); + GGML_API struct ggml_tensor * ggml_format_name( struct ggml_tensor * tensor, const char * fmt, ...); // // operations on tensors with backpropagation @@ -651,6 +681,11 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a); + // in-place, returns view(a) + GGML_API struct ggml_tensor * ggml_dup_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + GGML_API struct ggml_tensor * ggml_add( struct ggml_context * ctx, struct ggml_tensor * a, @@ -875,14 +910,17 @@ extern "C" { GGML_API struct ggml_tensor * ggml_rms_norm( struct ggml_context * ctx, - struct ggml_tensor * a); + struct ggml_tensor * a, + float eps); GGML_API struct ggml_tensor * ggml_rms_norm_inplace( struct ggml_context * ctx, - struct ggml_tensor * a); + struct ggml_tensor * a, + float eps); // a - x // b - dy + // TODO: update with configurable eps GGML_API struct ggml_tensor * ggml_rms_norm_back( struct ggml_context * ctx, struct ggml_tensor * a, @@ -974,11 +1012,22 @@ extern "C" { struct ggml_tensor * a, struct ggml_tensor * b); + // a -> b, in-place, return view(b) + GGML_API struct ggml_tensor * ggml_cpy_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + // make contiguous GGML_API struct ggml_tensor * ggml_cont( struct ggml_context * ctx, struct ggml_tensor * a); + // make contiguous, in-place + GGML_API struct ggml_tensor * ggml_cont_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + // return view(a), b specifies the new shape // TODO: when we start computing gradient, make a copy instead of view GGML_API struct ggml_tensor * ggml_reshape( @@ -1154,9 +1203,9 @@ extern "C" { int n_past, int n_dims, int mode, + int n_ctx, float freq_base, - float freq_scale, - int n_ctx); + float freq_scale); // rotary position embedding backward, i.e compute dx from dy // a - dy @@ -1165,7 +1214,8 @@ extern "C" { struct ggml_tensor * a, int n_past, int n_dims, - int mode); + int mode, + int n_ctx); // alibi position embedding // in-place, returns view(a) @@ -1289,6 +1339,16 @@ extern "C" { typedef void (*ggml_custom2_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *); typedef void (*ggml_custom3_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *); + GGML_API struct ggml_tensor * ggml_unary( + struct ggml_context * ctx, + struct ggml_tensor * a, + enum ggml_unary_op op); + + GGML_API struct ggml_tensor * ggml_unary_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + enum ggml_unary_op op); + GGML_API struct ggml_tensor * ggml_map_unary_f32( struct ggml_context * ctx, struct ggml_tensor * a, @@ -1368,11 +1428,17 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * tensor); + GGML_API void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor); GGML_API struct ggml_cgraph ggml_build_forward (struct ggml_tensor * tensor); GGML_API struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep); + // graph allocation in a context + GGML_API struct ggml_cgraph * ggml_new_graph (struct ggml_context * ctx); + GGML_API struct ggml_cgraph * ggml_build_forward_ctx(struct ggml_context * ctx, struct ggml_tensor * tensor); + GGML_API size_t ggml_graph_overhead(void); + // ggml_graph_plan() has to be called before ggml_graph_compute() // when plan.work_size > 0, caller must allocate memory for plan.work_data GGML_API struct ggml_cplan ggml_graph_plan (struct ggml_cgraph * cgraph, int n_threads /*= GGML_DEFAULT_N_THREADS*/); diff --git a/llama/k_quants.c b/llama/k_quants.c index 5f83e879..e6bd7889 100644 --- a/llama/k_quants.c +++ b/llama/k_quants.c @@ -1,5 +1,5 @@ /** - * llama.cpp - git e782c9e735f93ab4767ffc37462c523b73a17ddc + * llama.cpp - git 7c529cede6e84054e77a3eceab31c53de7b2f55b * * MIT License * @@ -1692,6 +1692,62 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri *s = hsum_float_8(acc) + summs; +#elif defined __AVX__ + + const __m128i m3 = _mm_set1_epi8(3); + + __m256 acc = _mm256_setzero_ps(); + + uint32_t ud, um; + const uint8_t * restrict db = (const uint8_t *)&ud; + const uint8_t * restrict mb = (const uint8_t *)&um; + + float summs = 0; + + // TODO: optimize this + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); + const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin); + + const uint8_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const uint32_t * restrict sc = (const uint32_t *)x[i].scales; + ud = (sc[0] >> 0) & 0x0f0f0f0f; + um = (sc[0] >> 4) & 0x0f0f0f0f; + + int32_t smin = mb[0] * y[i].bsums[0] + mb[1] * y[i].bsums[1] + mb[2] * y[i].bsums[2] + mb[3] * y[i].bsums[3]; + summs += dmin * smin; + + const __m128i q2bits = _mm_loadu_si128((const __m128i*)q2); + const __m128i q2_0 = _mm_and_si128(q2bits, m3); + const __m128i q2_1 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3); + const __m128i q2_2 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3); + const __m128i q2_3 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3); + + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); + + const __m128i p0 = _mm_maddubs_epi16(q2_0, _mm256_extractf128_si256(q8_0, 0)); + const __m128i p1 = _mm_maddubs_epi16(q2_1, _mm256_extractf128_si256(q8_0, 1)); + const __m128i p2 = _mm_maddubs_epi16(q2_2, _mm256_extractf128_si256(q8_1, 0)); + const __m128i p3 = _mm_maddubs_epi16(q2_3, _mm256_extractf128_si256(q8_1, 1)); + + const __m256i p_0 = _mm256_set_m128i(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p0, p0)), _mm_cvtepi16_epi32(p0)); + const __m256i p_1 = _mm256_set_m128i(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p1, p1)), _mm_cvtepi16_epi32(p1)); + const __m256i p_2 = _mm256_set_m128i(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p2, p2)), _mm_cvtepi16_epi32(p2)); + const __m256i p_3 = _mm256_set_m128i(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p3, p3)), _mm_cvtepi16_epi32(p3)); + + acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[0]), _mm256_cvtepi32_ps(p_0)), acc); + acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[1]), _mm256_cvtepi32_ps(p_1)), acc); + acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[2]), _mm256_cvtepi32_ps(p_2)), acc); + acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[3]), _mm256_cvtepi32_ps(p_3)), acc); + } + + *s = hsum_float_8(acc) + summs; + #else float sumf = 0; @@ -2321,6 +2377,93 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri *s = hsum_float_8(acc); +#elif defined __AVX__ + + const __m128i m3 = _mm_set1_epi8(3); + const __m128i m1 = _mm_set1_epi8(1); + + __m256 acc = _mm256_setzero_ps(); + + uint64_t aux64; + + uint16_t aux16[2]; + const int8_t * aux8 = (const int8_t *)aux16; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); + + const uint8_t * restrict q3 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const uint16_t a = *(const uint16_t *)x[i].scales; + aux16[0] = a & 0x0f0f; + aux16[1] = (a >> 4) & 0x0f0f; + + const __m128i scale_0 = _mm_set1_epi16(aux8[0] - 8); + const __m128i scale_1 = _mm_set1_epi16(aux8[2] - 8); + const __m128i scale_2 = _mm_set1_epi16(aux8[1] - 8); + const __m128i scale_3 = _mm_set1_epi16(aux8[3] - 8); + + memcpy(&aux64, x[i].hmask, 8); + + __m128i q3h_0 = _mm_set_epi64x(aux64 >> 1, aux64 >> 0); + __m128i q3h_1 = _mm_srli_epi16(q3h_0, 2); + __m128i q3h_2 = _mm_srli_epi16(q3h_0, 4); + __m128i q3h_3 = _mm_srli_epi16(q3h_0, 6); + q3h_0 = _mm_slli_epi16(_mm_andnot_si128(q3h_0, m1), 2); + q3h_1 = _mm_slli_epi16(_mm_andnot_si128(q3h_1, m1), 2); + q3h_2 = _mm_slli_epi16(_mm_andnot_si128(q3h_2, m1), 2); + q3h_3 = _mm_slli_epi16(_mm_andnot_si128(q3h_3, m1), 2); + + // load low 2 bits + const __m128i q3bits = _mm_loadu_si128((const __m128i*)q3); + + // prepare low and high bits + const __m128i q3l_0 = _mm_and_si128(q3bits, m3); + const __m128i q3l_1 = _mm_and_si128(_mm_srli_epi16(q3bits, 2), m3); + const __m128i q3l_2 = _mm_and_si128(_mm_srli_epi16(q3bits, 4), m3); + const __m128i q3l_3 = _mm_and_si128(_mm_srli_epi16(q3bits, 6), m3); + + // load Q8 quants + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); + + // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm_maddubs_epi16, + // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set, + // and 2 if the high bit was set) + const __m128i q8s_0 = _mm_maddubs_epi16(q3h_0, _mm256_extractf128_si256(q8_0, 0)); + const __m128i q8s_1 = _mm_maddubs_epi16(q3h_1, _mm256_extractf128_si256(q8_0, 1)); + const __m128i q8s_2 = _mm_maddubs_epi16(q3h_2, _mm256_extractf128_si256(q8_1, 0)); + const __m128i q8s_3 = _mm_maddubs_epi16(q3h_3, _mm256_extractf128_si256(q8_1, 1)); + + __m128i p16_0 = _mm_maddubs_epi16(q3l_0, _mm256_extractf128_si256(q8_0, 0)); + __m128i p16_1 = _mm_maddubs_epi16(q3l_1, _mm256_extractf128_si256(q8_0, 1)); + __m128i p16_2 = _mm_maddubs_epi16(q3l_2, _mm256_extractf128_si256(q8_1, 0)); + __m128i p16_3 = _mm_maddubs_epi16(q3l_3, _mm256_extractf128_si256(q8_1, 1)); + + p16_0 = _mm_sub_epi16(p16_0, q8s_0); + p16_1 = _mm_sub_epi16(p16_1, q8s_1); + p16_2 = _mm_sub_epi16(p16_2, q8s_2); + p16_3 = _mm_sub_epi16(p16_3, q8s_3); + + // multiply with scales + p16_0 = _mm_madd_epi16(scale_0, p16_0); + p16_1 = _mm_madd_epi16(scale_1, p16_1); + p16_2 = _mm_madd_epi16(scale_2, p16_2); + p16_3 = _mm_madd_epi16(scale_3, p16_3); + + p16_0 = _mm_add_epi32(p16_0, p16_2); + p16_1 = _mm_add_epi32(p16_1, p16_3); + __m256i p16 = _mm256_set_m128i(p16_1, p16_0); + + // multiply with block scale and accumulate + acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(p16)), acc); + + } + + *s = hsum_float_8(acc); + #else int8_t aux8[QK_K]; @@ -2807,6 +2950,60 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri *s = hsum_float_8(acc) - summs; +#elif defined __AVX__ + + const __m128i m4 = _mm_set1_epi8(0xF); + + __m256 acc = _mm256_setzero_ps(); + + float summs = 0; + + uint16_t aux16[2]; + const uint8_t * scales = (const uint8_t *)aux16; + + for (int i = 0; i < nb; ++i) { + + const float d = ggml_fp16_to_fp32(x[i].d[0]) * y[i].d; + const float m = ggml_fp16_to_fp32(x[i].d[1]) * y[i].d; + const __m256 vd = _mm256_set1_ps(d); + + const uint16_t * a = (const uint16_t *)x[i].scales; + aux16[0] = a[0] & 0x0f0f; + aux16[1] = (a[0] >> 4) & 0x0f0f; + + summs += m * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3])); + + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4); + const __m128i q4bits_0 = _mm256_extractf128_si256(q4bits, 0); + const __m128i q4bits_1 = _mm256_extractf128_si256(q4bits, 1); + const __m128i q4_0 = _mm_and_si128(q4bits_0, m4); + const __m128i q4_1 = _mm_and_si128(q4bits_1, m4); + const __m128i q4_2 = _mm_and_si128(_mm_srli_epi16(q4bits_0, 4), m4); + const __m128i q4_3 = _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4); + + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); + + const __m128i p16_0 = _mm_maddubs_epi16(q4_0, _mm256_extractf128_si256(q8_0, 0)); + const __m128i p16_1 = _mm_maddubs_epi16(q4_1, _mm256_extractf128_si256(q8_0, 1)); + const __m128i p16_2 = _mm_maddubs_epi16(q4_2, _mm256_extractf128_si256(q8_1, 0)); + const __m128i p16_3 = _mm_maddubs_epi16(q4_3, _mm256_extractf128_si256(q8_1, 1)); + + const __m128i p32_0 = _mm_madd_epi16(_mm_set1_epi16(scales[0]), p16_0); + const __m128i p32_1 = _mm_madd_epi16(_mm_set1_epi16(scales[0]), p16_1); + acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_set_m128i(p32_1, p32_0))), acc); + + const __m128i p32_2 = _mm_madd_epi16(_mm_set1_epi16(scales[1]), p16_2); + const __m128i p32_3 = _mm_madd_epi16(_mm_set1_epi16(scales[1]), p16_3); + acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_set_m128i(p32_3, p32_2))), acc); + + } + + *s = hsum_float_8(acc) - summs; + #else uint8_t aux8[QK_K]; @@ -3321,10 +3518,66 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri *s = hsum_float_8(acc); +#elif defined __AVX__ + + const __m128i m4 = _mm_set1_epi8(0xF); + const __m128i mone = _mm_set1_epi8(1); + + __m256 acc = _mm256_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + + const uint8_t * restrict q5 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); + + const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5); + + const __m128i scale_0 = _mm_set1_epi16(x[i].scales[0]); + const __m128i scale_1 = _mm_set1_epi16(x[i].scales[1]); + const __m128i scale_2 = _mm_set1_epi16(x[i].scales[2]); + const __m128i scale_3 = _mm_set1_epi16(x[i].scales[3]); + + int64_t aux64; + memcpy(&aux64, x[i].qh, 8); + const __m128i haux128_0 = _mm_set_epi64x(aux64 >> 1, aux64); + const __m128i haux128_1 = _mm_srli_epi16(haux128_0, 2); + + const __m128i q5h_0 = _mm_slli_epi16(_mm_andnot_si128(haux128_0, mone), 4); + const __m128i q5h_1 = _mm_slli_epi16(_mm_andnot_si128(haux128_1, mone), 4); + const __m128i q5h_2 = _mm_slli_epi16(_mm_andnot_si128(_mm_srli_epi16(haux128_0, 4), mone), 4); + const __m128i q5h_3 = _mm_slli_epi16(_mm_andnot_si128(_mm_srli_epi16(haux128_1, 4), mone), 4); + + const __m128i q5l_0 = _mm_and_si128(_mm256_extractf128_si256(q5bits, 0), m4); + const __m128i q5l_1 = _mm_and_si128(_mm256_extractf128_si256(q5bits, 1), m4); + const __m128i q5l_2 = _mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q5bits, 0), 4), m4); + const __m128i q5l_3 = _mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q5bits, 1), 4), m4); + + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); + + const __m128i p16_0 = _mm_madd_epi16(scale_0, _mm_maddubs_epi16(q5l_0, _mm256_extractf128_si256(q8_0, 0))); + const __m128i p16_1 = _mm_madd_epi16(scale_1, _mm_maddubs_epi16(q5l_1, _mm256_extractf128_si256(q8_0, 1))); + const __m128i p16_2 = _mm_madd_epi16(scale_2, _mm_maddubs_epi16(q5l_2, _mm256_extractf128_si256(q8_1, 0))); + const __m128i p16_3 = _mm_madd_epi16(scale_3, _mm_maddubs_epi16(q5l_3, _mm256_extractf128_si256(q8_1, 1))); + const __m128i s16_0 = _mm_madd_epi16(scale_0, _mm_maddubs_epi16(q5h_0, _mm256_extractf128_si256(q8_0, 0))); + const __m128i s16_1 = _mm_madd_epi16(scale_1, _mm_maddubs_epi16(q5h_1, _mm256_extractf128_si256(q8_0, 1))); + const __m128i s16_2 = _mm_madd_epi16(scale_2, _mm_maddubs_epi16(q5h_2, _mm256_extractf128_si256(q8_1, 0))); + const __m128i s16_3 = _mm_madd_epi16(scale_3, _mm_maddubs_epi16(q5h_3, _mm256_extractf128_si256(q8_1, 1))); + + const __m128i dot_0 = _mm_sub_epi32(_mm_add_epi32(p16_0, p16_2), _mm_add_epi32(s16_0, s16_2)); + const __m128i dot_1 = _mm_sub_epi32(_mm_add_epi32(p16_1, p16_3), _mm_add_epi32(s16_1, s16_3)); + + acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_set_m128i(dot_1, dot_0))), acc); + + } + + *s = hsum_float_8(acc); + #else - - uint8_t aux8[QK_K]; + int8_t aux8[QK_K]; int16_t aux16[16]; float sums [8]; memset(sums, 0, 8*sizeof(float)); @@ -3334,7 +3587,7 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri const uint8_t * restrict q4 = x[i].qs; const uint8_t * restrict hm = x[i].qh; const int8_t * restrict q8 = y[i].qs; - uint8_t * restrict a = aux8; + int8_t * restrict a = aux8; for (int l = 0; l < 32; ++l) { a[l+ 0] = q4[l] & 0xF; a[l+32] = q4[l] >> 4; @@ -3884,6 +4137,77 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri *s = hsum_float_8(acc); +#elif defined __AVX__ + + const __m128i m4 = _mm_set1_epi8(0xF); + const __m128i m2 = _mm_set1_epi8(3); + const __m128i m32s = _mm_set1_epi8(32); + + __m256 acc = _mm256_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); + + const uint8_t * restrict q4 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + const __m64 scales_1 = _mm_set1_pi8(x[i].scales[0]); + const __m64 scales_2 = _mm_set1_pi8(x[i].scales[1]); + const __m64 scales_3 = _mm_set1_pi8(x[i].scales[2]); + const __m64 scales_4 = _mm_set1_pi8(x[i].scales[3]); + + __m128i sumi_0 = _mm_setzero_si128(); + __m128i sumi_1 = _mm_setzero_si128(); + + const __m128i scale_0 = _mm_set_epi64(scales_2, scales_1); + const __m128i scale_1 = _mm_set_epi64(scales_4, scales_3); + + const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4); + const __m128i q4bitsH = _mm_loadu_si128((const __m128i*)qh); + + const __m128i q4h_0 = _mm_slli_epi16(_mm_and_si128(q4bitsH, m2), 4); + const __m128i q4h_1 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 2), m2), 4); + const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 4), m2), 4); + const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 6), m2), 4); + + const __m128i q4_0 = _mm_or_si128(_mm_and_si128(_mm256_extractf128_si256(q4bits1, 0), m4), q4h_0); + const __m128i q4_1 = _mm_or_si128(_mm_and_si128(_mm256_extractf128_si256(q4bits1, 1), m4), q4h_1); + const __m128i q4_2 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q4bits1, 0), 4), m4), q4h_2); + const __m128i q4_3 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q4bits1, 1), 4), m4), q4h_3); + + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); + + __m128i q8s_0 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_0, 0)); + __m128i q8s_1 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_0, 1)); + __m128i q8s_2 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_1, 0)); + __m128i q8s_3 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_1, 1)); + + __m128i p16_0 = _mm_maddubs_epi16(q4_0, _mm256_extractf128_si256(q8_0, 0)); + __m128i p16_1 = _mm_maddubs_epi16(q4_1, _mm256_extractf128_si256(q8_0, 1)); + __m128i p16_2 = _mm_maddubs_epi16(q4_2, _mm256_extractf128_si256(q8_1, 0)); + __m128i p16_3 = _mm_maddubs_epi16(q4_3, _mm256_extractf128_si256(q8_1, 1)); + + p16_0 = _mm_sub_epi16(p16_0, q8s_0); + p16_1 = _mm_sub_epi16(p16_1, q8s_1); + p16_2 = _mm_sub_epi16(p16_2, q8s_2); + p16_3 = _mm_sub_epi16(p16_3, q8s_3); + + p16_0 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0); + p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_0, scale_0)), p16_1); + p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2); + p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_1, scale_1)), p16_3); + + sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2)); + sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3)); + + acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(_mm256_set_m128i(sumi_1, sumi_0))), acc); + } + + *s = hsum_float_8(acc); + #else int8_t aux8[QK_K]; diff --git a/llama/k_quants.h b/llama/k_quants.h index 7623ffc5..177de2d3 100644 --- a/llama/k_quants.h +++ b/llama/k_quants.h @@ -1,5 +1,5 @@ /** - * llama.cpp - git e782c9e735f93ab4767ffc37462c523b73a17ddc + * llama.cpp - git 7c529cede6e84054e77a3eceab31c53de7b2f55b * * MIT License * diff --git a/llama/llama-util.h b/llama/llama-util.h index 0a934b52..0424ed65 100644 --- a/llama/llama-util.h +++ b/llama/llama-util.h @@ -1,5 +1,5 @@ /** - * llama.cpp - git e782c9e735f93ab4767ffc37462c523b73a17ddc + * llama.cpp - git 7c529cede6e84054e77a3eceab31c53de7b2f55b * * MIT License * diff --git a/llama/llama.cpp b/llama/llama.cpp index a963047a..25c29e79 100644 --- a/llama/llama.cpp +++ b/llama/llama.cpp @@ -1,5 +1,5 @@ /** - * llama.cpp - git e782c9e735f93ab4767ffc37462c523b73a17ddc + * llama.cpp - git 7c529cede6e84054e77a3eceab31c53de7b2f55b * * MIT License * @@ -93,6 +93,7 @@ enum e_model { MODEL_13B, MODEL_30B, MODEL_65B, + MODEL_70B, }; static const size_t kB = 1024; @@ -124,18 +125,18 @@ static void ggml_graph_compute_helper(std::vector & buf, ggml_cgraph * } // -// memory sizes +// memory sizes (calculated for n_batch == 512) // static const std::map & MEM_REQ_SCRATCH0(int n_ctx) { static std::map k_sizes = { - /* empirical scaling, still a guess */ - { MODEL_3B, ((size_t) n_ctx / 16ull + 128ull) * MB }, - { MODEL_7B, ((size_t) n_ctx / 16ull + 256ull) * MB }, - { MODEL_13B, ((size_t) n_ctx / 12ull + 256ull) * MB }, - { MODEL_30B, ((size_t) n_ctx / 10ull + 256ull) * MB }, - { MODEL_65B, ((size_t) n_ctx / 8ull + 512ull) * MB }, + { MODEL_3B, ((size_t) n_ctx / 16ull + 92ull) * MB }, + { MODEL_7B, ((size_t) n_ctx / 16ull + 100ull) * MB }, + { MODEL_13B, ((size_t) n_ctx / 12ull + 120ull) * MB }, + { MODEL_30B, ((size_t) n_ctx / 9ull + 160ull) * MB }, + { MODEL_65B, ((size_t) n_ctx / 6ull + 256ull) * MB }, // guess + { MODEL_70B, ((size_t) n_ctx / 7ull + 164ull) * MB }, }; return k_sizes; } @@ -143,38 +144,26 @@ static const std::map & MEM_REQ_SCRATCH0(int n_ctx) static const std::map & MEM_REQ_SCRATCH1() { static std::map k_sizes = { - { MODEL_3B, 256ull * MB }, - { MODEL_7B, 512ull * MB }, - { MODEL_13B, 512ull * MB }, - { MODEL_30B, 512ull * MB }, - { MODEL_65B, 1024ull * MB }, + { MODEL_3B, 128ull * MB }, + { MODEL_7B, 160ull * MB }, + { MODEL_13B, 192ull * MB }, + { MODEL_30B, 256ull * MB }, + { MODEL_65B, 384ull * MB }, // guess + { MODEL_70B, 304ull * MB }, }; return k_sizes; } -// 2*n_embd*n_ctx*n_layer*sizeof(float16) -static const std::map & MEM_REQ_KV_SELF() +// used to store the compute graph tensors + non-scratch data +static const std::map & MEM_REQ_EVAL() { static std::map k_sizes = { - { MODEL_3B, 682ull * MB }, - { MODEL_7B, 1026ull * MB }, - { MODEL_13B, 1608ull * MB }, - { MODEL_30B, 3124ull * MB }, - { MODEL_65B, 5120ull * MB }, - }; - return k_sizes; -} - -// this is mostly needed for temporary mul_mat buffers to dequantize the data -// not actually needed if BLAS is disabled -static const std::map & MEM_REQ_EVAL(int n_ctx) -{ - static std::map k_sizes = { - { MODEL_3B, ((size_t) n_ctx / 256ull + 512ull) * MB }, - { MODEL_7B, ((size_t) n_ctx / 256ull + 768ull) * MB }, - { MODEL_13B, ((size_t) n_ctx / 256ull + 1024ull) * MB }, - { MODEL_30B, ((size_t) n_ctx / 256ull + 1280ull) * MB }, - { MODEL_65B, ((size_t) n_ctx / 256ull + 1536ull) * MB }, + { MODEL_3B, 8ull * MB }, + { MODEL_7B, 10ull * MB }, + { MODEL_13B, 12ull * MB }, + { MODEL_30B, 16ull * MB }, + { MODEL_65B, 24ull * MB }, // guess + { MODEL_70B, 24ull * MB }, }; return k_sizes; } @@ -189,6 +178,7 @@ static const std::map & VRAM_REQ_SCRATCH_BASE() { MODEL_13B, 640ull * kB }, { MODEL_30B, 768ull * kB }, { MODEL_65B, 1536ull * kB }, + { MODEL_70B, 1536ull * kB }, // TODO (likely can be reduced) }; return k_sizes; } @@ -203,19 +193,26 @@ static const std::map & VRAM_REQ_SCRATCH_PER_CONTEXT() { MODEL_13B, 160ull }, { MODEL_30B, 208ull }, { MODEL_65B, 416ull }, + { MODEL_70B, 416ull }, // TODO (likely can be reduced) }; return k_sizes; } // default hparams (LLaMA 7B) struct llama_hparams { - uint32_t n_vocab = 32000; - uint32_t n_ctx = 512; // this is provided as user input? - uint32_t n_embd = 4096; - uint32_t n_mult = 256; - uint32_t n_head = 32; - uint32_t n_layer = 32; - uint32_t n_rot = 64; + uint32_t n_vocab = 32000; + uint32_t n_ctx = 512; // this is provided as user input? + uint32_t n_embd = 4096; + uint32_t n_mult = 256; + uint32_t n_head = 32; + uint32_t n_head_kv = 32; + uint32_t n_layer = 32; + uint32_t n_rot = 64; + + // LLaMAv2 + // TODO: load from model data hparams + float f_ffn_mult = 1.0f; + float f_rms_norm_eps = LLAMA_DEFAULT_RMS_EPS; float rope_freq_base = 10000.0f; float rope_freq_scale = 1.0f; @@ -223,7 +220,28 @@ struct llama_hparams { enum llama_ftype ftype = LLAMA_FTYPE_MOSTLY_F16; bool operator!=(const llama_hparams & other) const { - return static_cast(memcmp(this, &other, sizeof(llama_hparams))); + return static_cast(memcmp(this, &other, sizeof(llama_hparams))); // NOLINT + } + + uint32_t n_gqa() const { + return n_head/n_head_kv; + } + + uint32_t n_embd_head() const { + return n_embd/n_head; + } + + uint32_t n_embd_gqa() const { + return n_embd/n_gqa(); + } + + size_t kv_size() const { + size_t result = 2ull; + result *= (size_t) n_embd_gqa(); + result *= (size_t) n_ctx; + result *= (size_t) n_layer; + result *= sizeof(ggml_fp16_t); + return result; } }; @@ -525,12 +543,16 @@ struct llama_file_loader { } void read_hparams() { hparams.n_vocab = file.read_u32(); - hparams.n_embd = file.read_u32(); - hparams.n_mult = file.read_u32(); - hparams.n_head = file.read_u32(); + hparams.n_embd = file.read_u32(); + hparams.n_mult = file.read_u32(); + hparams.n_head = file.read_u32(); hparams.n_layer = file.read_u32(); - hparams.n_rot = file.read_u32(); - hparams.ftype = (enum llama_ftype) file.read_u32(); + hparams.n_rot = file.read_u32(); + hparams.ftype = (enum llama_ftype) file.read_u32(); + + // LLaMAv2 + // TODO: read from header + hparams.n_head_kv = hparams.n_head; } void read_vocab() { vocab.id_to_token.resize(hparams.n_vocab); @@ -829,7 +851,7 @@ static bool kv_cache_init( ggml_type wtype, int n_ctx, int n_gpu_layers) { - const int n_embd = hparams.n_embd; + const int n_embd = hparams.n_embd_gqa(); const int n_layer = hparams.n_layer; const int64_t n_mem = n_layer*n_ctx; @@ -873,9 +895,11 @@ struct llama_context_params llama_context_default_params() { /*.seed =*/ LLAMA_DEFAULT_SEED, /*.n_ctx =*/ 512, /*.n_batch =*/ 512, + /*.n_gqa =*/ 1, + /*.rms_norm_eps =*/ LLAMA_DEFAULT_RMS_EPS, /*.gpu_layers =*/ 0, /*.main_gpu =*/ 0, - /*.tensor_split =*/ {0}, + /*.tensor_split =*/ nullptr, /*.rope_freq_base =*/ 10000.0f, /*.rope_freq_scale =*/ 1.0f, /*.progress_callback =*/ nullptr, @@ -992,6 +1016,7 @@ static const char *llama_model_type_name(e_model type) { case MODEL_13B: return "13B"; case MODEL_30B: return "30B"; case MODEL_65B: return "65B"; + case MODEL_70B: return "70B"; default: LLAMA_ASSERT(false); } } @@ -1002,6 +1027,8 @@ static void llama_model_load_internal( llama_vocab & vocab, int n_ctx, int n_batch, + int n_gqa, + float rms_norm_eps, int n_gpu_layers, int main_gpu, const float * tensor_split, @@ -1023,8 +1050,12 @@ static void llama_model_load_internal( model.hparams = ml->file_loader->hparams; model.n_gpu_layers = n_gpu_layers; llama_file_version file_version = ml->file_loader->file_version; + auto & hparams = model.hparams; + // TODO: read from file + hparams.f_rms_norm_eps = rms_norm_eps; + { switch (hparams.n_layer) { case 26: model.type = e_model::MODEL_3B; break; @@ -1042,11 +1073,25 @@ static void llama_model_load_internal( hparams.n_ctx = n_ctx; + // LLaMAv2 + // TODO: temporary until GGUF + LLAMA_ASSERT(hparams.n_head % n_gqa == 0); + hparams.n_head_kv = hparams.n_head / n_gqa; + if (model.type == e_model::MODEL_65B && n_gqa == 8) { + fprintf(stderr, "%s: warning: assuming 70B model based on GQA == %d\n", __func__, n_gqa); + model.type = e_model::MODEL_70B; + hparams.f_ffn_mult = 1.3f; // from the params.json of the 70B model + } + hparams.rope_freq_base = rope_freq_base; hparams.rope_freq_scale = rope_freq_scale; } - const uint32_t n_ff = ((2*(4*hparams.n_embd)/3 + hparams.n_mult - 1)/hparams.n_mult)*hparams.n_mult; + // ref: https://github.com/facebookresearch/llama/blob/6c7fe276574e78057f917549435a2554000a876d/llama/model.py#L194-L199 + const uint32_t n_ff_raw = 2*(4*hparams.n_embd)/3; + const uint32_t n_ff_mult = hparams.f_ffn_mult*n_ff_raw; + const uint32_t n_ff = ((n_ff_mult + hparams.n_mult - 1)/hparams.n_mult)*hparams.n_mult; + //const uint32_t n_ff = 28672; { fprintf(stderr, "%s: format = %s\n", __func__, llama_file_version_name(file_version)); @@ -1055,12 +1100,15 @@ static void llama_model_load_internal( fprintf(stderr, "%s: n_embd = %u\n", __func__, hparams.n_embd); fprintf(stderr, "%s: n_mult = %u\n", __func__, hparams.n_mult); fprintf(stderr, "%s: n_head = %u\n", __func__, hparams.n_head); + fprintf(stderr, "%s: n_head_kv = %u\n", __func__, hparams.n_head_kv); fprintf(stderr, "%s: n_layer = %u\n", __func__, hparams.n_layer); - fprintf(stderr, "%s: n_rot = %u\n", __func__, hparams.n_rot); + fprintf(stderr, "%s: n_rot = %u\n", __func__, hparams.n_rot); // a.k.a. n_embd_head, n_head_dim + fprintf(stderr, "%s: n_gqa = %u\n", __func__, hparams.n_gqa()); + fprintf(stderr, "%s: rnorm_eps = %.1e\n", __func__, hparams.f_rms_norm_eps); + fprintf(stderr, "%s: n_ff = %u\n", __func__, n_ff); fprintf(stderr, "%s: freq_base = %.1f\n", __func__, hparams.rope_freq_base); fprintf(stderr, "%s: freq_scale = %g\n", __func__, hparams.rope_freq_scale); fprintf(stderr, "%s: ftype = %u (%s)\n", __func__, hparams.ftype, llama_ftype_name(hparams.ftype)); - fprintf(stderr, "%s: n_ff = %u\n", __func__, n_ff); fprintf(stderr, "%s: model size = %s\n", __func__, llama_model_type_name(model.type)); } @@ -1095,7 +1143,7 @@ static void llama_model_load_internal( { model.buf.resize(ctx_size); if (use_mlock) { - model.mlock_buf.init(model.buf.addr); + model.mlock_buf.init (model.buf.addr); model.mlock_buf.grow_to(model.buf.size); } @@ -1130,9 +1178,10 @@ static void llama_model_load_internal( size_t vram_weights = 0; size_t vram_scratch = 0; { - const uint32_t n_embd = hparams.n_embd; - const uint32_t n_layer = hparams.n_layer; - const uint32_t n_vocab = hparams.n_vocab; + const uint32_t n_embd = hparams.n_embd; + const uint32_t n_embd_gqa = hparams.n_embd_gqa(); + const uint32_t n_layer = hparams.n_layer; + const uint32_t n_vocab = hparams.n_vocab; ml->ggml_ctx = ctx; @@ -1180,16 +1229,16 @@ static void llama_model_load_internal( layer.attention_norm = ml->get_tensor(layers_i + ".attention_norm.weight", {n_embd}, backend); - layer.wq = ml->get_tensor(layers_i + ".attention.wq.weight", {n_embd, n_embd}, backend_split); - layer.wk = ml->get_tensor(layers_i + ".attention.wk.weight", {n_embd, n_embd}, backend_split); - layer.wv = ml->get_tensor(layers_i + ".attention.wv.weight", {n_embd, n_embd}, backend_split); - layer.wo = ml->get_tensor(layers_i + ".attention.wo.weight", {n_embd, n_embd}, backend_split); + layer.wq = ml->get_tensor(layers_i + ".attention.wq.weight", {n_embd, n_embd}, backend_split); + layer.wk = ml->get_tensor(layers_i + ".attention.wk.weight", {n_embd, n_embd_gqa}, backend_split); + layer.wv = ml->get_tensor(layers_i + ".attention.wv.weight", {n_embd, n_embd_gqa}, backend_split); + layer.wo = ml->get_tensor(layers_i + ".attention.wo.weight", {n_embd, n_embd}, backend_split); layer.ffn_norm = ml->get_tensor(layers_i + ".ffn_norm.weight", {n_embd}, backend); - layer.w1 = ml->get_tensor(layers_i + ".feed_forward.w1.weight", {n_embd, n_ff}, backend_split); - layer.w2 = ml->get_tensor(layers_i + ".feed_forward.w2.weight", { n_ff, n_embd}, backend_split); - layer.w3 = ml->get_tensor(layers_i + ".feed_forward.w3.weight", {n_embd, n_ff}, backend_split); + layer.w1 = ml->get_tensor(layers_i + ".feed_forward.w1.weight", {n_embd, n_ff}, backend_split); + layer.w2 = ml->get_tensor(layers_i + ".feed_forward.w2.weight", { n_ff, n_embd}, backend_split); + layer.w3 = ml->get_tensor(layers_i + ".feed_forward.w3.weight", {n_embd, n_ff}, backend_split); if (backend == GGML_BACKEND_GPU) { vram_weights += @@ -1212,11 +1261,11 @@ static void llama_model_load_internal( mmapped_size - vram_weights + // weights in VRAM not in memory MEM_REQ_SCRATCH0(hparams.n_ctx).at(model.type) + MEM_REQ_SCRATCH1().at(model.type) + - MEM_REQ_EVAL(hparams.n_ctx).at(model.type); + MEM_REQ_EVAL().at(model.type); // this is the memory required by one llama_state const size_t mem_required_state = - scale*MEM_REQ_KV_SELF().at(model.type); + scale*hparams.kv_size(); fprintf(stderr, "%s: mem required = %7.2f MB (+ %7.2f MB per state)\n", __func__, mem_required / 1024.0 / 1024.0, mem_required_state / 1024.0 / 1024.0); @@ -1257,7 +1306,7 @@ static void llama_model_load_internal( fprintf(stderr, "%s: cannot offload v cache to GPU due to low VRAM option\n", __func__); } else { fprintf(stderr, "%s: offloading v cache to GPU\n", __func__); - vram_kv_cache += MEM_REQ_KV_SELF().at(model.type) / 2; + vram_kv_cache += hparams.kv_size() / 2; } } if (n_gpu_layers > (int) hparams.n_layer + 2) { @@ -1265,7 +1314,7 @@ static void llama_model_load_internal( fprintf(stderr, "%s: cannot offload k cache to GPU due to low VRAM option\n", __func__); } else { fprintf(stderr, "%s: offloading k cache to GPU\n", __func__); - vram_kv_cache += MEM_REQ_KV_SELF().at(model.type) / 2; + vram_kv_cache += hparams.kv_size() / 2; } } #elif defined(GGML_USE_CLBLAST) @@ -1313,9 +1362,11 @@ static bool llama_model_load( llama_vocab & vocab, int n_ctx, int n_batch, + int n_gqa, + float rms_norm_eps, int n_gpu_layers, int main_gpu, - float * tensor_split, + const float * tensor_split, float rope_freq_base, float rope_freq_scale, bool low_vram, @@ -1326,7 +1377,7 @@ static bool llama_model_load( llama_progress_callback progress_callback, void *progress_callback_user_data) { try { - llama_model_load_internal(fname, model, vocab, n_ctx, n_batch, n_gpu_layers, main_gpu, tensor_split, rope_freq_base, rope_freq_scale, low_vram, memory_type, + llama_model_load_internal(fname, model, vocab, n_ctx, n_batch, n_gqa, rms_norm_eps, n_gpu_layers, main_gpu, tensor_split, rope_freq_base, rope_freq_scale, low_vram, memory_type, use_mmap, use_mlock, vocab_only, progress_callback, progress_callback_user_data); return true; } catch (const std::exception & err) { @@ -1370,16 +1421,23 @@ static bool llama_eval_internal( LLAMA_ASSERT(!!kv_self.ctx); - const int n_embd = hparams.n_embd; - const int n_layer = hparams.n_layer; - const int n_ctx = hparams.n_ctx; - const int n_head = hparams.n_head; - const int n_vocab = hparams.n_vocab; - const int n_rot = hparams.n_embd/hparams.n_head; - const int n_gpu_layers = model.n_gpu_layers; + const int64_t n_embd = hparams.n_embd; + const int64_t n_layer = hparams.n_layer; + const int64_t n_ctx = hparams.n_ctx; + const int64_t n_head = hparams.n_head; + const int64_t n_head_kv = hparams.n_head_kv; + const int64_t n_embd_head = hparams.n_embd_head(); + const int64_t n_vocab = hparams.n_vocab; + const int64_t n_embd_gqa = hparams.n_embd_gqa(); + + + LLAMA_ASSERT(n_embd_head == hparams.n_rot); const float freq_base = hparams.rope_freq_base; const float freq_scale = hparams.rope_freq_scale; + const float rms_norm_eps = hparams.f_rms_norm_eps; + + const int n_gpu_layers = model.n_gpu_layers; auto & mem_per_token = lctx.mem_per_token; auto & buf_compute = lctx.buf_compute; @@ -1392,7 +1450,7 @@ static bool llama_eval_internal( struct ggml_context * ctx0 = ggml_init(params); - ggml_cgraph gf = {}; + ggml_cgraph * gf = ggml_new_graph(ctx0); // for big prompts, if BLAS is enabled, it is better to use only one thread // otherwise, the threads are spin-lock waiting for the BLAS calls and are degrading the performance @@ -1457,7 +1515,7 @@ static bool llama_eval_internal( // norm { - cur = ggml_rms_norm(ctx0, inpL); + cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps); offload_func(cur); ggml_set_name(cur, "rms_norm_0"); @@ -1478,11 +1536,11 @@ static bool llama_eval_internal( offload_func_kq(tmpq); ggml_set_name(tmpq, "tmpq"); - struct ggml_tensor * Kcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd/n_head, n_head, N), n_past, n_rot, 0, freq_base, freq_scale, 0); + struct ggml_tensor * Kcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale); offload_func_kq(Kcur); ggml_set_name(Kcur, "Kcur"); - struct ggml_tensor * Qcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd/n_head, n_head, N), n_past, n_rot, 0, freq_base, freq_scale, 0); + struct ggml_tensor * Qcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale); offload_func_kq(Qcur); ggml_set_name(Qcur, "Qcur"); @@ -1494,23 +1552,23 @@ static bool llama_eval_internal( offload_func_v(tmpv); ggml_set_name(tmpv, "tmpv"); - struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, tmpv, n_embd, N)); + struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, tmpv, n_embd_gqa, N)); offload_func_v(Vcur); ggml_set_name(Vcur, "Vcur"); - struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd, (ggml_element_size(kv_self.k)*n_embd)*(il*n_ctx + n_past)); + struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + n_past)); offload_func_kq(k); ggml_set_name(k, "k"); - struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd, + struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd_gqa, ( n_ctx)*ggml_element_size(kv_self.v), - (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd + n_past*ggml_element_size(kv_self.v)); + (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + n_past*ggml_element_size(kv_self.v)); offload_func_v(v); ggml_set_name(v, "v"); // important: storing RoPE-ed version of K in the KV cache! - ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k)); - ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v)); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k)); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v)); } struct ggml_tensor * Q = @@ -1523,8 +1581,8 @@ static bool llama_eval_internal( struct ggml_tensor * K = ggml_permute(ctx0, ggml_reshape_3d(ctx0, - ggml_view_1d(ctx0, kv_self.k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(kv_self.k)*n_embd), - n_embd/n_head, n_head, n_past + N), + ggml_view_1d(ctx0, kv_self.k, (n_past + N)*n_embd_gqa, il*n_ctx*ggml_element_size(kv_self.k)*n_embd_gqa), + n_embd_head, n_head_kv, n_past + N), 0, 2, 1, 3); offload_func_kq(K); ggml_set_name(K, "K"); @@ -1534,9 +1592,9 @@ static bool llama_eval_internal( offload_func_kq(KQ); ggml_set_name(KQ, "KQ"); - // KQ_scaled = KQ / sqrt(n_embd/n_head) + // KQ_scaled = KQ / sqrt(n_embd_head) struct ggml_tensor * KQ_scale = ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head)); - ggml_set_name(KQ_scale, "1/sqrt(n_embd/n_head)"); + ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)"); // KQ_scaled shape [n_past + N, N, n_head, 1] struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale); @@ -1556,10 +1614,10 @@ static bool llama_eval_internal( // split cached V into n_head heads struct ggml_tensor * V = ggml_view_3d(ctx0, kv_self.v, - n_past + N, n_embd/n_head, n_head, + n_past + N, n_embd_head, n_head_kv, n_ctx*ggml_element_size(kv_self.v), - n_ctx*ggml_element_size(kv_self.v)*n_embd/n_head, - il*n_ctx*ggml_element_size(kv_self.v)*n_embd); + n_ctx*ggml_element_size(kv_self.v)*n_embd_head, + n_ctx*ggml_element_size(kv_self.v)*n_embd_gqa*il); offload_func_v(V); ggml_set_name(V, "V"); @@ -1571,7 +1629,7 @@ static bool llama_eval_internal( // make V contiguous in memory to speed up the matmul, however we waste time on the copy // on M1 this is faster for the perplexity computation, but ~5% slower for the single-token generation // is there a better way? - struct ggml_tensor * V_cont = ggml_cpy(ctx0, V, ggml_new_tensor_3d(ctx0, kv_self.v->type, n_past + N, n_embd/n_head, n_head)); + struct ggml_tensor * V_cont = ggml_cpy(ctx0, V, ggml_new_tensor_3d(ctx0, kv_self.v->type, n_past + N, n_embd_head, n_head)); struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_cont, KQ_soft_max); #endif @@ -1605,7 +1663,7 @@ static bool llama_eval_internal( { // norm { - cur = ggml_rms_norm(ctx0, inpFF); + cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps); offload_func(cur); ggml_set_name(cur, "rms_norm_1"); @@ -1658,7 +1716,7 @@ static bool llama_eval_internal( // norm { - cur = ggml_rms_norm(ctx0, inpL); + cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps); offload_func_nr(cur); ggml_set_name(cur, "rms_norm_2"); @@ -1680,16 +1738,22 @@ static bool llama_eval_internal( //cur = ggml_soft_max_inplace(ctx0, cur); // run the computation - ggml_build_forward_expand(&gf, cur); + ggml_build_forward_expand(gf, cur); + + // fprintf(stderr, "graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf.n_nodes, gf.n_leafs); #if GGML_USE_MPI - ggml_mpi_graph_compute_pre(lctx.ctx_mpi, &gf, n_layer); + ggml_mpi_graph_compute_pre(lctx.ctx_mpi, gf, n_layer); #endif #ifdef GGML_USE_METAL if (lctx.ctx_metal && N == 1) { + // TODO: disabled until #2413 is resolved + //if (!ggml_metal_if_optimized(lctx.ctx_metal)) { + // ggml_metal_graph_find_concurrency(lctx.ctx_metal, gf); + //} ggml_metal_set_n_cb (lctx.ctx_metal, n_threads); - ggml_metal_graph_compute(lctx.ctx_metal, &gf); + ggml_metal_graph_compute(lctx.ctx_metal, gf); ggml_metal_get_tensor (lctx.ctx_metal, cur); } else { // IMPORTANT: @@ -1708,34 +1772,34 @@ static bool llama_eval_internal( ggml_metal_get_tensor(lctx.ctx_metal, kv_self.v); } - ggml_graph_compute_helper(lctx.work_buffer, &gf, n_threads); + ggml_graph_compute_helper(lctx.work_buffer, gf, n_threads); } #else - ggml_graph_compute_helper(lctx.work_buffer, &gf, n_threads); + ggml_graph_compute_helper(lctx.work_buffer, gf, n_threads); #endif #if GGML_USE_MPI - ggml_mpi_graph_compute_post(lctx.ctx_mpi, &gf, n_layer); + ggml_mpi_graph_compute_post(lctx.ctx_mpi, gf, n_layer); #endif // update kv token count lctx.kv_self.n = n_past + N; - struct ggml_tensor * res = gf.nodes[gf.n_nodes - 1]; + struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1]; if (cgraph_fname) { - ggml_graph_export(&gf, cgraph_fname); + ggml_graph_export(gf, cgraph_fname); } #ifdef GGML_PERF // print timing information per ggml operation (for debugging purposes) // requires GGML_PERF to be defined - ggml_graph_print(&gf); + ggml_graph_print(gf); #endif // plot the computation graph in dot format (for debugging purposes) //if (n_past%100 == 0) { - // ggml_graph_dump_dot(&gf, NULL, "llama.dot"); + // ggml_graph_dump_dot(gf, NULL, "llama.dot"); //} // extract logits @@ -1765,10 +1829,12 @@ static bool llama_eval_internal( } #if 0 - printf("\n%s: used_mem = %.3f MB, scratch -- %.3f MB %.3f MB\n", __func__, + printf("\n%s: used_mem: eval ctx %.3f MB, scratch %.3f MB %.3f MB, work buf %.3f MB, n_past = %d, N = %d\n", __func__, ggml_used_mem(ctx0)/1024.0/1024.0, lctx.get_buf_max_mem(0)/1024.0/1024.0, - lctx.get_buf_max_mem(1)/1024.0/1024.0); + lctx.get_buf_max_mem(1)/1024.0/1024.0, + lctx.work_buffer.size()/1024.0/1024.0, + n_past, N); #endif ggml_free(ctx0); @@ -1941,6 +2007,279 @@ static std::vector llama_tokenize(const llama_vocab & vocab, co return output; } +// +// grammar - internal +// + +struct llama_grammar { + const std::vector> rules; + std::vector> stacks; +}; + +struct llama_grammar_candidate { + size_t index; + const uint32_t * code_points; +}; + +// NOTE: assumes valid utf8 (but checks for overrun) +// adds a terminating 0 for use as pointer +std::vector decode_utf8(const char * src) { + static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; + const char * pos = src; + std::vector code_points; + while (*pos != 0) { + uint8_t first_byte = static_cast(*pos); + uint8_t highbits = first_byte >> 4; + int len = lookup[highbits]; + uint8_t mask = (1 << (8 - len)) - 1; + uint32_t value = first_byte & mask; + const char * end = pos + len; // may overrun! + ++pos; + for ( ; pos < end && *pos != 0; ++pos) { + value = (value << 6) + (static_cast(*pos) & 0x3F); + } + code_points.push_back(value); + } + code_points.push_back(0); + return code_points; +} + +// returns true iff pos points to the end of one of the definitions of a rule +static bool llama_grammar_is_end_of_sequence(const llama_grammar_element * pos) { + switch (pos->type) { + case LLAMA_GRETYPE_END: return true; + case LLAMA_GRETYPE_ALT: return true; + default: return false; + } +} + +// returns true iff chr satisfies the char range at pos (regular or inverse range) +// asserts that pos is pointing to a char range element +static std::pair llama_grammar_match_char( + const llama_grammar_element * pos, + const uint32_t chr) { + + bool found = false; + bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR; + LLAMA_ASSERT(is_positive_char || pos->type == LLAMA_GRETYPE_CHAR_NOT); + + do { + if (pos[1].type == LLAMA_GRETYPE_CHAR_RNG_UPPER) { + // inclusive range, e.g. [a-z] + found = found || (pos->value <= chr && chr <= pos[1].value); + pos += 2; + } else { + // exact char match, e.g. [a] or "a" + found = found || pos->value == chr; + pos += 1; + } + } while (pos->type == LLAMA_GRETYPE_CHAR_ALT); + + return std::make_pair(found == is_positive_char, pos); +} + +// transforms a grammar pushdown stack into N possible stacks, all ending +// at a character range (terminal element) +static void llama_grammar_advance_stack( + const std::vector> & rules, + const std::vector & stack, + std::vector> & new_stacks) { + + if (stack.empty()) { + new_stacks.push_back(stack); + return; + } + + const llama_grammar_element * pos = stack.back(); + + switch (pos->type) { + case LLAMA_GRETYPE_RULE_REF: { + const size_t rule_id = static_cast(pos->value); + const llama_grammar_element * subpos = rules[rule_id].data(); + do { + // init new stack without the top (pos) + std::vector new_stack(stack.begin(), stack.end() - 1); + if (!llama_grammar_is_end_of_sequence(pos + 1)) { + // if this rule ref is followed by another element, add that to stack + new_stack.push_back(pos + 1); + } + if (!llama_grammar_is_end_of_sequence(subpos)) { + // if alternate is nonempty, add to stack + new_stack.push_back(subpos); + } + llama_grammar_advance_stack(rules, new_stack, new_stacks); + while (!llama_grammar_is_end_of_sequence(subpos)) { + // scan to end of alternate def + subpos++; + } + if (subpos->type == LLAMA_GRETYPE_ALT) { + // there's another alternate def of this rule to process + subpos++; + } else { + break; + } + } while (true); + break; + } + case LLAMA_GRETYPE_CHAR: + case LLAMA_GRETYPE_CHAR_NOT: + new_stacks.push_back(stack); + break; + default: + // end of alternate (LLAMA_GRETYPE_END, LLAMA_GRETYPE_ALT) or middle of char range + // (LLAMA_GRETYPE_CHAR_ALT, LLAMA_GRETYPE_CHAR_RNG_UPPER); stack should never be left on + // those + LLAMA_ASSERT(false); + } +} + +// takes a set of possible pushdown stacks on a grammar, which are required to +// be positioned at a character range (see `llama_grammar_advance_stack`), and +// produces the N possible stacks if the given char is accepted at those +// positions +static std::vector> llama_grammar_accept( + const std::vector> & rules, + const std::vector> & stacks, + const uint32_t chr) { + + std::vector> new_stacks; + + for (const auto & stack : stacks) { + if (stack.empty()) { + continue; + } + + auto match = llama_grammar_match_char(stack.back(), chr); + if (match.first) { + const llama_grammar_element * pos = match.second; + + // update top of stack to next element, if any + std::vector new_stack(stack.begin(), stack.end() - 1); + if (!llama_grammar_is_end_of_sequence(pos)) { + new_stack.push_back(pos); + } + llama_grammar_advance_stack(rules, new_stack, new_stacks); + } + } + + return new_stacks; +} + +static std::vector llama_grammar_reject_candidates( + const std::vector> & rules, + const std::vector> & stacks, + const std::vector & candidates); + +static std::vector llama_grammar_reject_candidates_for_stack( + const std::vector> & rules, + const std::vector & stack, + const std::vector & candidates) { + + std::vector rejects; + + if (stack.empty()) { + // accept nothing; EOS is handled elsewhere + rejects.insert(rejects.end(), candidates.begin(), candidates.end()); + return rejects; + } + + const llama_grammar_element * stack_pos = stack.back(); + + std::vector next_candidates; + for (auto tok : candidates) { + if (llama_grammar_match_char(stack_pos, tok.code_points[0]).first) { + if (tok.code_points[1] != 0) { + next_candidates.push_back({ tok.index, tok.code_points + 1 }); + } + } else { + rejects.push_back(tok); + } + } + + auto stack_pos_after = llama_grammar_match_char(stack_pos, 0).second; + + // update top of stack to next element, if any + std::vector stack_after(stack.begin(), stack.end() - 1); + if (!llama_grammar_is_end_of_sequence(stack_pos_after)) { + stack_after.push_back(stack_pos_after); + } + std::vector> next_stacks; + llama_grammar_advance_stack(rules, stack_after, next_stacks); + + auto next_rejects = llama_grammar_reject_candidates(rules, next_stacks, next_candidates); + for (auto tok : next_rejects) { + rejects.push_back({ tok.index, tok.code_points - 1 }); + } + + return rejects; +} + +static std::vector llama_grammar_reject_candidates( + const std::vector> & rules, + const std::vector> & stacks, + const std::vector & candidates) { + LLAMA_ASSERT(!stacks.empty()); // REVIEW + + if (candidates.empty()) { + return std::vector(); + } + + auto rejects = llama_grammar_reject_candidates_for_stack(rules, stacks.front(), candidates); + + for (size_t i = 1, size = stacks.size(); i < size; ++i) { + rejects = llama_grammar_reject_candidates_for_stack(rules, stacks[i], rejects); + } + return rejects; +} + +// +// grammar - external +// + +struct llama_grammar * llama_grammar_init( + const llama_grammar_element ** rules, + size_t n_rules, + size_t start_rule_index) { + const llama_grammar_element * pos; + + // copy rule definitions into vectors + std::vector> vec_rules(n_rules); + for (size_t i = 0; i < n_rules; i++) { + for (pos = rules[i]; pos->type != LLAMA_GRETYPE_END; pos++) { + vec_rules[i].push_back(*pos); + } + vec_rules[i].push_back({LLAMA_GRETYPE_END, 0}); + } + + // loop over alternates of start rule to build initial stacks + std::vector> stacks; + pos = rules[start_rule_index]; + do { + std::vector stack; + if (!llama_grammar_is_end_of_sequence(pos)) { + // if alternate is nonempty, add to stack + stack.push_back(pos); + } + llama_grammar_advance_stack(vec_rules, stack, stacks); + while (!llama_grammar_is_end_of_sequence(pos)) { + // scan to end of alternate def + pos++; + } + if (pos->type == LLAMA_GRETYPE_ALT) { + // there's another alternate def of this rule to process + pos++; + } else { + break; + } + } while (true); + + return new llama_grammar{ std::move(vec_rules), std::move(stacks) }; +} + +void llama_grammar_free(struct llama_grammar * grammar) { + delete grammar; +} + // // sampling // @@ -2226,6 +2565,47 @@ void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, l } } +void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar) { + assert(ctx); + const int64_t t_start_sample_us = ggml_time_us(); + + bool allow_eos = false; + for (const auto & stack : grammar->stacks) { + if (stack.empty()) { + allow_eos = true; + break; + } + } + + const llama_token eos = llama_token_eos(); + + std::vector> candidates_decoded; + std::vector candidates_grammar; + + for (size_t i = 0; i < candidates->size; ++i) { + const llama_token id = candidates->data[i].id; + const char * str = llama_token_to_str(ctx, id); + if (id == eos) { + if (!allow_eos) { + candidates->data[i].logit = -INFINITY; + } + } else if (*str == 0) { + candidates->data[i].logit = -INFINITY; + } else { + candidates_decoded.push_back(decode_utf8(str)); + candidates_grammar.push_back({ i, candidates_decoded.back().data() }); + } + } + + const auto rejects = + llama_grammar_reject_candidates(grammar->rules, grammar->stacks, candidates_grammar); + for (auto & reject : rejects) { + candidates->data[reject.index].logit = -INFINITY; + } + + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; +} + static void llama_log_softmax(float * array, size_t size) { float max_l = *std::max_element(array, array + size); float sum = 0.f; @@ -2244,8 +2624,7 @@ void llama_sample_classifier_free_guidance( struct llama_context * ctx, llama_token_data_array * candidates, struct llama_context * guidance_ctx, - float scale, - float smooth_factor) { + float scale) { int64_t t_start_sample_us = ggml_time_us(); assert(ctx); @@ -2266,16 +2645,7 @@ void llama_sample_classifier_free_guidance( for (int i = 0; i < n_vocab; ++i) { float logit_guidance = logits_guidance[i]; float logit_base = logits_base[i]; - logits_guidance[i] = scale * (logit_base - logit_guidance) + logit_guidance; - } - - llama_log_softmax(logits_guidance, n_vocab); - - for (int i = 0; i < n_vocab; ++i) { - float logit_base = logits_base[i]; - float logit_guidance = logits_guidance[i]; - - candidates->data[i].logit = smooth_factor * logit_guidance + (1.f - smooth_factor) * logit_base; + candidates->data[i].logit = scale * (logit_base - logit_guidance) + logit_guidance; } if (ctx) { @@ -2411,6 +2781,29 @@ llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_arra return result; } +void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token) { + const int64_t t_start_sample_us = ggml_time_us(); + + if (token == llama_token_eos()) { + for (const auto & stack : grammar->stacks) { + if (stack.empty()) { + return; + } + } + LLAMA_ASSERT(false); + } + + const char * str = llama_token_to_str(ctx, token); + // Note terminating 0 in decoded string + auto code_points = decode_utf8(str); + for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { + grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it); + } + LLAMA_ASSERT(!grammar->stacks.empty()); + + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; +} + // // quantization // @@ -2484,8 +2877,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s case LLAMA_FTYPE_MOSTLY_Q5_0: quantized_type = GGML_TYPE_Q5_0; break; case LLAMA_FTYPE_MOSTLY_Q5_1: quantized_type = GGML_TYPE_Q5_1; break; case LLAMA_FTYPE_MOSTLY_Q8_0: quantized_type = GGML_TYPE_Q8_0; break; - case LLAMA_FTYPE_MOSTLY_F16: quantized_type = GGML_TYPE_F16; break; - case LLAMA_FTYPE_ALL_F32: quantized_type = GGML_TYPE_F32; break; + case LLAMA_FTYPE_MOSTLY_F16: quantized_type = GGML_TYPE_F16; break; + case LLAMA_FTYPE_ALL_F32: quantized_type = GGML_TYPE_F32; break; #ifdef GGML_USE_K_QUANTS // K-quants @@ -2569,16 +2962,6 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s } else { new_type = quantized_type; #ifdef GGML_USE_K_QUANTS - bool convert_incompatible_tensor = false; - if (quantized_type == GGML_TYPE_Q2_K || quantized_type == GGML_TYPE_Q3_K || quantized_type == GGML_TYPE_Q4_K || - quantized_type == GGML_TYPE_Q5_K || quantized_type == GGML_TYPE_Q6_K) { - int nx = tensor.ne.at(0); - int ny = tensor.ne.at(1); - if (nx % QK_K != 0 || ny % QK_K != 0) { - fprintf(stderr, "\n\nTensor sizes %d x %d are not divisible by %d, required for k-quants.\n",nx,ny,QK_K); - convert_incompatible_tensor = true; - } - } if (tensor.name == "output.weight") { int nx = tensor.ne.at(0); int ny = tensor.ne.at(1); @@ -2604,6 +2987,16 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q4_K; else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K; } + bool convert_incompatible_tensor = false; + if (new_type == GGML_TYPE_Q2_K || new_type == GGML_TYPE_Q3_K || new_type == GGML_TYPE_Q4_K || + new_type == GGML_TYPE_Q5_K || new_type == GGML_TYPE_Q6_K) { + int nx = tensor.ne.at(0); + int ny = tensor.ne.at(1); + if (nx % QK_K != 0 || ny % QK_K != 0) { + fprintf(stderr, "\n\nTensor sizes %d x %d are not divisible by %d, required for k-quants.\n",nx,ny,QK_K); + convert_incompatible_tensor = true; + } + } if (convert_incompatible_tensor) { if (tensor.name == "output.weight") { new_type = GGML_TYPE_F16; //fall back to F16 instead of just failing. @@ -2630,7 +3023,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s f32_data = (float *) f32_conv_buf.addr; } - printf("quantizing .. "); + printf("quantizing to %s .. ", ggml_type_name(new_type)); fflush(stdout); work.resize(nelements * 4); // upper bound on size @@ -2733,7 +3126,7 @@ struct llama_model * llama_load_model_from_file( ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32; - if (!llama_model_load(path_model, *model, model->vocab, params.n_ctx, params.n_batch, params.n_gpu_layers, + if (!llama_model_load(path_model, *model, model->vocab, params.n_ctx, params.n_batch, params.n_gqa, params.rms_norm_eps, params.n_gpu_layers, params.main_gpu, params.tensor_split, params.rope_freq_base, params.rope_freq_scale,params.low_vram, memory_type, params.use_mmap, params.use_mlock, params.vocab_only, params.progress_callback, params.progress_callback_user_data)) { @@ -2811,7 +3204,7 @@ struct llama_context * llama_new_context_with_model( ctx->embedding.resize(hparams.n_embd); } - ctx->buf_compute.resize(MEM_REQ_EVAL(hparams.n_ctx).at(ctx->model.type)); + ctx->buf_compute.resize(MEM_REQ_EVAL().at(ctx->model.type) + ggml_graph_overhead()); ctx->buf_scratch[0].resize(MEM_REQ_SCRATCH0(hparams.n_ctx).at(ctx->model.type)); ctx->buf_scratch[1].resize(MEM_REQ_SCRATCH1().at(ctx->model.type)); @@ -2835,7 +3228,7 @@ struct llama_context * llama_new_context_with_model( const size_t max_size = ggml_get_max_tensor_size(ctx->model.ctx); - printf("%s: max tensor size = %8.2f MB\n", __func__, max_size/1024.0/1024.0); + fprintf(stderr, "%s: max tensor size = %8.2f MB\n", __func__, max_size/1024.0/1024.0); #define LLAMA_METAL_CHECK_BUF(result) \ if (!(result)) { \ diff --git a/llama/llama.go b/llama/llama.go index e2c30f1f..c7bf194a 100644 --- a/llama/llama.go +++ b/llama/llama.go @@ -127,6 +127,7 @@ func New(model string, opts api.Options) (*LLM, error) { params.seed = C.uint(llm.Seed) params.n_ctx = C.int(llm.NumCtx) params.n_batch = C.int(llm.NumBatch) + params.n_gqa = C.int(llm.NumGQA) params.n_gpu_layers = C.int(llm.NumGPU) params.main_gpu = C.int(llm.MainGPU) params.low_vram = C.bool(llm.LowVRAM) diff --git a/llama/llama.h b/llama/llama.h index d3cb65ea..4a92566f 100644 --- a/llama/llama.h +++ b/llama/llama.h @@ -1,5 +1,5 @@ /** - * llama.cpp - git e782c9e735f93ab4767ffc37462c523b73a17ddc + * llama.cpp - git 7c529cede6e84054e77a3eceab31c53de7b2f55b * * MIT License * @@ -79,6 +79,10 @@ #define LLAMA_SUPPORTS_GPU_OFFLOAD #endif +#ifndef LLAMA_DEFAULT_RMS_EPS +#define LLAMA_DEFAULT_RMS_EPS 5e-6f +#endif + #ifdef __cplusplus extern "C" { #endif @@ -109,12 +113,15 @@ extern "C" { typedef void (*llama_progress_callback)(float progress, void *ctx); struct llama_context_params { - uint32_t seed; // RNG seed, -1 for random - int32_t n_ctx; // text context - int32_t n_batch; // prompt processing batch size - int32_t n_gpu_layers; // number of layers to store in VRAM - int32_t main_gpu; // the GPU that is used for scratch and small tensors - float tensor_split[LLAMA_MAX_DEVICES]; // how to split layers across multiple GPUs + uint32_t seed; // RNG seed, -1 for random + int32_t n_ctx; // text context + int32_t n_batch; // prompt processing batch size + int32_t n_gqa; // grouped-query attention (TEMP - will be moved to model hparams) + float rms_norm_eps; // rms norm epsilon (TEMP - will be moved to model hparams) + int32_t n_gpu_layers; // number of layers to store in VRAM + int32_t main_gpu; // the GPU that is used for scratch and small tensors + + const float * tensor_split; // how to split layers across multiple GPUs (size: LLAMA_MAX_DEVICES) // ref: https://github.com/ggerganov/llama.cpp/pull/2054 float rope_freq_base; // RoPE base frequency @@ -165,6 +172,40 @@ extern "C" { bool quantize_output_tensor; // quantize output.weight } llama_model_quantize_params; + // grammar types + struct llama_grammar; + + // grammar element type + enum llama_gretype { + // end of rule definition + LLAMA_GRETYPE_END = 0, + + // start of alternate definition for rule + LLAMA_GRETYPE_ALT = 1, + + // non-terminal element: reference to rule + LLAMA_GRETYPE_RULE_REF = 2, + + // terminal element: character (code point) + LLAMA_GRETYPE_CHAR = 3, + + // inverse char(s) ([^a], [^a-b] [^abc]) + LLAMA_GRETYPE_CHAR_NOT = 4, + + // modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to + // be an inclusive range ([a-z]) + LLAMA_GRETYPE_CHAR_RNG_UPPER = 5, + + // modifies a preceding LLAMA_GRETYPE_CHAR or + // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA]) + LLAMA_GRETYPE_CHAR_ALT = 6, + }; + + typedef struct llama_grammar_element { + enum llama_gretype type; + uint32_t value; // Unicode code point or rule ID + } llama_grammar_element; + // performance timing information struct llama_timings { double t_start_ms; @@ -357,6 +398,15 @@ extern "C" { LLAMA_API llama_token llama_token_eos(); // end-of-sentence LLAMA_API llama_token llama_token_nl(); // next-line + // Grammar + // + LLAMA_API struct llama_grammar * llama_grammar_init( + const llama_grammar_element ** rules, + size_t n_rules, + size_t start_rule_index); + + LLAMA_API void llama_grammar_free(struct llama_grammar * grammar); + // Sampling functions /// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. @@ -369,13 +419,11 @@ extern "C" { /// @param candidates A vector of `llama_token_data` containing the candidate tokens, the logits must be directly extracted from the original generation context without being sorted. /// @params guidance_ctx A separate context from the same model. Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context. /// @params scale Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance. - /// @params smooth_factor Smooth factor between guidance logits and original logits. 1.0f means only use guidance logits. 0.0f means only original logits. LLAMA_API void llama_sample_classifier_free_guidance( struct llama_context * ctx, llama_token_data_array * candidates, struct llama_context * guidance_ctx, - float scale, - float smooth_factor); + float scale); /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits. LLAMA_API void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates); @@ -393,6 +441,9 @@ extern "C" { LLAMA_API void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep); LLAMA_API void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates, float temp); + /// @details Apply constraints from grammar + LLAMA_API void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar); + /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. @@ -414,6 +465,9 @@ extern "C" { /// @details Randomly selects a token from the candidates based on their probabilities. LLAMA_API llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates); + /// @details Accepts the sampled token into the grammar + LLAMA_API void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token); + // Performance information LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx); LLAMA_API void llama_print_timings(struct llama_context * ctx);