diff --git a/llm/llama.go b/llm/llama.go
index 8aa0f300..6577abbc 100644
--- a/llm/llama.go
+++ b/llm/llama.go
@@ -442,68 +442,18 @@ func (llm *llama) SetOptions(opts api.Options) {
 	llm.Options = opts
 }
 
-type GenerationSettings struct {
-	FrequencyPenalty float64       `json:"frequency_penalty"`
-	IgnoreEOS        bool          `json:"ignore_eos"`
-	LogitBias        []interface{} `json:"logit_bias"`
-	Mirostat         int           `json:"mirostat"`
-	MirostatEta      float64       `json:"mirostat_eta"`
-	MirostatTau      float64       `json:"mirostat_tau"`
-	Model            string        `json:"model"`
-	NCtx             int           `json:"n_ctx"`
-	NKeep            int           `json:"n_keep"`
-	NPredict         int           `json:"n_predict"`
-	NProbs           int           `json:"n_probs"`
-	PenalizeNl       bool          `json:"penalize_nl"`
-	PresencePenalty  float64       `json:"presence_penalty"`
-	RepeatLastN      int           `json:"repeat_last_n"`
-	RepeatPenalty    float64       `json:"repeat_penalty"`
-	Seed             uint32        `json:"seed"`
-	Stop             []string      `json:"stop"`
-	Stream           bool          `json:"stream"`
-	Temp             float64       `json:"temp"`
-	TfsZ             float64       `json:"tfs_z"`
-	TopK             int           `json:"top_k"`
-	TopP             float64       `json:"top_p"`
-	TypicalP         float64       `json:"typical_p"`
-}
-
-type Timings struct {
-	PredictedN  int     `json:"predicted_n"`
-	PredictedMS float64 `json:"predicted_ms"`
-	PromptN     int     `json:"prompt_n"`
-	PromptMS    float64 `json:"prompt_ms"`
-}
-
-type Prediction struct {
+type prediction struct {
 	Content string `json:"content"`
 	Model   string `json:"model"`
 	Prompt  string `json:"prompt"`
 	Stop    bool   `json:"stop"`
 
-	Timings `json:"timings"`
-}
-
-type PredictRequest struct {
-	Prompt           string   `json:"prompt"`
-	Stream           bool     `json:"stream"`
-	NPredict         int      `json:"n_predict"`
-	NKeep            int      `json:"n_keep"`
-	Temperature      float32  `json:"temperature"`
-	TopK             int      `json:"top_k"`
-	TopP             float32  `json:"top_p"`
-	TfsZ             float32  `json:"tfs_z"`
-	TypicalP         float32  `json:"typical_p"`
-	RepeatLastN      int      `json:"repeat_last_n"`
-	RepeatPenalty    float32  `json:"repeat_penalty"`
-	PresencePenalty  float32  `json:"presence_penalty"`
-	FrequencyPenalty float32  `json:"frequency_penalty"`
-	Mirostat         int      `json:"mirostat"`
-	MirostatTau      float32  `json:"mirostat_tau"`
-	MirostatEta      float32  `json:"mirostat_eta"`
-	PenalizeNl       bool     `json:"penalize_nl"`
-	Seed             int      `json:"seed"`
-	Stop             []string `json:"stop,omitempty"`
+	Timings struct {
+		PredictedN  int     `json:"predicted_n"`
+		PredictedMS float64 `json:"predicted_ms"`
+		PromptN     int     `json:"prompt_n"`
+		PromptMS    float64 `json:"prompt_ms"`
+	}
 }
 
 const maxBufferSize = 512 * format.KiloByte
@@ -518,27 +468,26 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string,
 	nextContext.WriteString(prevConvo)
 	nextContext.WriteString(prompt)
 
-	endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", llm.Port)
-	predReq := PredictRequest{
-		Prompt:           nextContext.String(),
-		Stream:           true,
-		NPredict:         llm.NumPredict,
-		NKeep:            llm.NumKeep,
-		Temperature:      llm.Temperature,
-		TopK:             llm.TopK,
-		TopP:             llm.TopP,
-		TfsZ:             llm.TFSZ,
-		TypicalP:         llm.TypicalP,
-		RepeatLastN:      llm.RepeatLastN,
-		RepeatPenalty:    llm.RepeatPenalty,
-		PresencePenalty:  llm.PresencePenalty,
-		FrequencyPenalty: llm.FrequencyPenalty,
-		Mirostat:         llm.Mirostat,
-		MirostatTau:      llm.MirostatTau,
-		MirostatEta:      llm.MirostatEta,
-		PenalizeNl:       llm.PenalizeNewline,
-		Seed:             llm.Seed,
-		Stop:             llm.Stop,
+	request := map[string]any{
+		"prompt":            nextContext.String(),
+		"stream":            true,
+		"n_predict":         llm.NumPredict,
+		"n_keep":            llm.NumKeep,
+		"temperature":       llm.Temperature,
+		"top_k":             llm.TopK,
+		"top_p":             llm.TopP,
+		"tfs_z":             llm.TFSZ,
+		"typical_p":         llm.TypicalP,
+		"repeat_last_n":     llm.RepeatLastN,
+		"repeat_penalty":    llm.RepeatPenalty,
+		"presence_penalty":  llm.PresencePenalty,
+		"frequency_penalty": llm.FrequencyPenalty,
+		"mirostat":          llm.Mirostat,
+		"mirostat_tau":      llm.MirostatTau,
+		"mirostat_eta":      llm.MirostatEta,
+		"penalize_nl":       llm.PenalizeNewline,
+		"seed":              llm.Seed,
+		"stop":              llm.Stop,
 	}
 
 	// Handling JSON marshaling with special characters unescaped.
@@ -546,10 +495,11 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string,
 	enc := json.NewEncoder(buffer)
 	enc.SetEscapeHTML(false)
-	if err := enc.Encode(predReq); err != nil {
+	if err := enc.Encode(request); err != nil {
 		return fmt.Errorf("failed to marshal data: %v", err)
 	}
 
+	endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", llm.Port)
 	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, buffer)
 	if err != nil {
 		return fmt.Errorf("error creating POST request: %v", err)
 	}
@@ -581,16 +531,14 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string,
 			// This handles the request cancellation
 			return ctx.Err()
 		default:
-			line := scanner.Text()
-			if line == "" {
+			line := scanner.Bytes()
+			if len(line) == 0 {
 				continue
 			}
 
-			// Read data from the server-side event stream
-			if strings.HasPrefix(line, "data: ") {
-				evt := line[6:]
-				var p Prediction
-				if err := json.Unmarshal([]byte(evt), &p); err != nil {
+			if evt, ok := bytes.CutPrefix(line, []byte("data: ")); ok {
+				var p prediction
+				if err := json.Unmarshal(evt, &p); err != nil {
 					return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
 				}
 
@@ -608,10 +556,10 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string,
 					fn(api.GenerateResponse{
 						Done:               true,
 						Context:            embd,
-						PromptEvalCount:    p.PromptN,
-						PromptEvalDuration: parseDurationMs(p.PromptMS),
-						EvalCount:          p.PredictedN,
-						EvalDuration:       parseDurationMs(p.PredictedMS),
+						PromptEvalCount:    p.Timings.PromptN,
+						PromptEvalDuration: parseDurationMs(p.Timings.PromptMS),
+						EvalCount:          p.Timings.PredictedN,
+						EvalDuration:       parseDurationMs(p.Timings.PredictedMS),
 					})
 
 					return nil