diff --git a/llama/llama.go b/llama/llama.go index c9c9c8fd..e5804e1f 100644 --- a/llama/llama.go +++ b/llama/llama.go @@ -36,7 +36,7 @@ import ( ) type LLama struct { - ctx unsafe.Pointer + ctx unsafe.Pointer embeddings bool contextSize int } @@ -68,6 +68,7 @@ func (l *LLama) Eval(text string, opts ...PredictOption) error { if po.Tokens == 0 { po.Tokens = 99999999 } + defer C.free(unsafe.Pointer(input)) reverseCount := len(po.StopPrompts) reversePrompt := make([]*C.char, reverseCount) @@ -76,6 +77,7 @@ func (l *LLama) Eval(text string, opts ...PredictOption) error { cs := C.CString(s) reversePrompt[i] = cs pass = &reversePrompt[0] + defer C.free(unsafe.Pointer(cs)) } params := C.llama_allocate_params(input, C.int(po.Seed), C.int(po.Threads), C.int(po.Tokens), C.int(po.TopK), @@ -88,13 +90,13 @@ func (l *LLama) Eval(text string, opts ...PredictOption) error { C.CString(po.MainGPU), C.CString(po.TensorSplit), C.bool(po.PromptCacheRO), ) + defer C.llama_free_params(params) + ret := C.eval(params, l.ctx, input) if ret != 0 { return fmt.Errorf("inference failed") } - C.llama_free_params(params) - return nil } @@ -109,6 +111,8 @@ func (l *LLama) Predict(text string, opts ...PredictOption) (string, error) { if po.Tokens == 0 { po.Tokens = 99999999 } + defer C.free(unsafe.Pointer(input)) + out := make([]byte, po.Tokens) reverseCount := len(po.StopPrompts) @@ -118,18 +122,32 @@ func (l *LLama) Predict(text string, opts ...PredictOption) (string, error) { cs := C.CString(s) reversePrompt[i] = cs pass = &reversePrompt[0] + defer C.free(unsafe.Pointer(cs)) } + cLogitBias := C.CString(po.LogitBias) + defer C.free(unsafe.Pointer(cLogitBias)) + + cPathPromptCache := C.CString(po.PathPromptCache) + defer C.free(unsafe.Pointer(cPathPromptCache)) + + cMainGPU := C.CString(po.MainGPU) + defer C.free(unsafe.Pointer(cMainGPU)) + + cTensorSplit := C.CString(po.TensorSplit) + defer C.free(unsafe.Pointer(cTensorSplit)) + params := C.llama_allocate_params(input, C.int(po.Seed), C.int(po.Threads), C.int(po.Tokens), C.int(po.TopK), C.float(po.TopP), C.float(po.Temperature), C.float(po.Penalty), C.int(po.Repeat), C.bool(po.IgnoreEOS), C.bool(po.F16KV), C.int(po.Batch), C.int(po.NKeep), pass, C.int(reverseCount), C.float(po.TailFreeSamplingZ), C.float(po.TypicalP), C.float(po.FrequencyPenalty), C.float(po.PresencePenalty), - C.int(po.Mirostat), C.float(po.MirostatETA), C.float(po.MirostatTAU), C.bool(po.PenalizeNL), C.CString(po.LogitBias), - C.CString(po.PathPromptCache), C.bool(po.PromptCacheAll), C.bool(po.MLock), C.bool(po.MMap), - C.CString(po.MainGPU), C.CString(po.TensorSplit), + C.int(po.Mirostat), C.float(po.MirostatETA), C.float(po.MirostatTAU), C.bool(po.PenalizeNL), cLogitBias, + cPathPromptCache, C.bool(po.PromptCacheAll), C.bool(po.MLock), C.bool(po.MMap), + cMainGPU, cTensorSplit, C.bool(po.PromptCacheRO), ) + defer C.llama_free_params(params) ret := C.llama_predict(params, l.ctx, (*C.char)(unsafe.Pointer(&out[0])), C.bool(po.DebugMode)) if ret != 0 { @@ -145,8 +163,6 @@ func (l *LLama) Predict(text string, opts ...PredictOption) (string, error) { res = strings.TrimRight(res, s) } - C.llama_free_params(params) - if po.TokenCallback != nil { setCallback(l.ctx, nil) }