From 35af37a2cb7097dcbac2a0f88eb2636436f82d2a Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Tue, 18 Jul 2023 11:59:42 -0700 Subject: [PATCH] session id --- api/types.go | 8 +++++--- cmd/cmd.go | 24 ++++++++++++++-------- llama/llama.go | 18 ++++++++-------- server/routes.go | 53 +++++++++++++++++++++++++++++++++--------------- 4 files changed, 67 insertions(+), 36 deletions(-) diff --git a/api/types.go b/api/types.go index 07ce8122..42b0c470 100644 --- a/api/types.go +++ b/api/types.go @@ -28,9 +28,10 @@ func (e StatusError) Error() string { } type GenerateRequest struct { - Model string `json:"model"` - Prompt string `json:"prompt"` - Context []int `json:"context,omitempty"` + SessionID int64 `json:"session_id"` + Model string `json:"model"` + Prompt string `json:"prompt"` + Context []int `json:"context,omitempty"` Options `json:"options"` } @@ -81,6 +82,7 @@ type ListResponseModel struct { } type GenerateResponse struct { + SessionID int64 `json:"session_id"` Model string `json:"model"` CreatedAt time.Time `json:"created_at"` Response string `json:"response,omitempty"` diff --git a/cmd/cmd.go b/cmd/cmd.go index 7761b03b..b9c07cff 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -244,7 +244,7 @@ func RunGenerate(cmd *cobra.Command, args []string) error { return generateBatch(cmd, args[0]) } -var generateContextKey struct{} +type generateContextKey string func generate(cmd *cobra.Command, model, prompt string) error { if len(strings.TrimSpace(prompt)) > 0 { @@ -255,22 +255,25 @@ func generate(cmd *cobra.Command, model, prompt string) error { var latest api.GenerateResponse - generateContext, ok := cmd.Context().Value(generateContextKey).([]int) + generateContext, ok := cmd.Context().Value(generateContextKey("context")).([]int) if !ok { generateContext = []int{} } - request := api.GenerateRequest{Model: model, Prompt: prompt, Context: generateContext} - fn := func(resp api.GenerateResponse) error { + generateSession, ok := cmd.Context().Value(generateContextKey("session")).(int64) + if !ok { + generateSession = 0 + } + + request := api.GenerateRequest{Model: model, Prompt: prompt, Context: generateContext, SessionID: generateSession} + fn := func(response api.GenerateResponse) error { if !spinner.IsFinished() { spinner.Finish() } - latest = resp + latest = response - fmt.Print(resp.Response) - - cmd.SetContext(context.WithValue(cmd.Context(), generateContextKey, resp.Context)) + fmt.Print(response.Response) return nil } @@ -289,6 +292,11 @@ func generate(cmd *cobra.Command, model, prompt string) error { if verbose { latest.Summary() } + + ctx := cmd.Context() + ctx = context.WithValue(ctx, generateContextKey("context"), latest.Context) + ctx = context.WithValue(ctx, generateContextKey("session"), latest.SessionID) + cmd.SetContext(ctx) } return nil diff --git a/llama/llama.go b/llama/llama.go index 37b4c143..a48c5965 100644 --- a/llama/llama.go +++ b/llama/llama.go @@ -91,7 +91,7 @@ import ( "github.com/jmorganca/ollama/api" ) -type llama struct { +type LLM struct { params *C.struct_llama_context_params model *C.struct_llama_model ctx *C.struct_llama_context @@ -99,12 +99,12 @@ type llama struct { api.Options } -func New(model string, opts api.Options) (*llama, error) { +func New(model string, opts api.Options) (*LLM, error) { if _, err := os.Stat(model); err != nil { return nil, err } - llm := llama{Options: opts} + llm := LLM{Options: opts} C.llama_backend_init(C.bool(llm.UseNUMA)) @@ -144,14 +144,14 @@ func New(model string, opts api.Options) (*llama, error) { return &llm, nil } -func 
(llm *llama) Close() { +func (llm *LLM) Close() { defer C.llama_free_model(llm.model) defer C.llama_free(llm.ctx) C.llama_print_timings(llm.ctx) } -func (llm *llama) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) error { +func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) error { if input := llm.tokenize(prompt); input != nil { embd := make([]C.llama_token, len(ctx)) for i := range ctx { @@ -164,7 +164,7 @@ func (llm *llama) Predict(ctx []int, prompt string, fn func(api.GenerateResponse return errors.New("llama: tokenize") } -func (llm *llama) tokenize(prompt string) []C.llama_token { +func (llm *LLM) tokenize(prompt string) []C.llama_token { cPrompt := C.CString(prompt) defer C.free(unsafe.Pointer(cPrompt)) @@ -176,7 +176,7 @@ func (llm *llama) tokenize(prompt string) []C.llama_token { return nil } -func (llm *llama) detokenize(tokens ...C.llama_token) string { +func (llm *LLM) detokenize(tokens ...C.llama_token) string { var sb strings.Builder for _, token := range tokens { sb.WriteString(C.GoString(C.llama_token_to_str(llm.ctx, token))) @@ -185,7 +185,7 @@ func (llm *llama) detokenize(tokens ...C.llama_token) string { return sb.String() } -func (llm *llama) generate(input []C.llama_token, fn func(api.GenerateResponse)) error { +func (llm *LLM) generate(input []C.llama_token, fn func(api.GenerateResponse)) error { var opts C.struct_llama_sample_options opts.repeat_penalty = C.float(llm.RepeatPenalty) opts.frequency_penalty = C.float(llm.FrequencyPenalty) @@ -256,7 +256,7 @@ func (llm *llama) generate(input []C.llama_token, fn func(api.GenerateResponse)) return nil } -func (llm *llama) sample(output deque[C.llama_token], opts *C.struct_llama_sample_options) (C.llama_token, error) { +func (llm *LLM) sample(output deque[C.llama_token], opts *C.struct_llama_sample_options) (C.llama_token, error) { numVocab := int(C.llama_n_vocab(llm.ctx)) logits := unsafe.Slice(C.llama_get_logits(llm.ctx), numVocab) diff --git a/server/routes.go b/server/routes.go index aabcb718..93a04cd7 100644 --- a/server/routes.go +++ b/server/routes.go @@ -11,6 +11,7 @@ import ( "os" "path/filepath" "strings" + "sync" "time" "dario.cat/mergo" @@ -21,7 +22,17 @@ import ( "github.com/jmorganca/ollama/llama" ) +var mu sync.Mutex + +var activeSession struct { + ID int64 + *llama.LLM +} + func GenerateHandler(c *gin.Context) { + mu.Lock() + defer mu.Unlock() + start := time.Now() var req api.GenerateRequest @@ -36,15 +47,31 @@ func GenerateHandler(c *gin.Context) { return } - opts := api.DefaultOptions() - if err := mergo.Merge(&opts, model.Options, mergo.WithOverride); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } + if req.SessionID == 0 || req.SessionID != activeSession.ID { + if activeSession.LLM != nil { + activeSession.Close() + activeSession.LLM = nil + } - if err := mergo.Merge(&opts, req.Options, mergo.WithOverride); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return + opts := api.DefaultOptions() + if err := mergo.Merge(&opts, model.Options, mergo.WithOverride); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + if err := mergo.Merge(&opts, req.Options, mergo.WithOverride); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + llm, err := llama.New(model.ModelPath, opts) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + 
activeSession.ID = time.Now().UnixNano() + activeSession.LLM = llm } prompt, err := model.Prompt(req) @@ -53,19 +80,13 @@ func GenerateHandler(c *gin.Context) { return } - llm, err := llama.New(model.ModelPath, opts) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - defer llm.Close() - ch := make(chan any) go func() { defer close(ch) fn := func(r api.GenerateResponse) { r.Model = req.Model r.CreatedAt = time.Now().UTC() + r.SessionID = activeSession.ID if r.Done { r.TotalDuration = time.Since(start) } @@ -73,7 +94,7 @@ func GenerateHandler(c *gin.Context) { ch <- r } - if err := llm.Predict(req.Context, prompt, fn); err != nil { + if err := activeSession.LLM.Predict(req.Context, prompt, fn); err != nil { ch <- gin.H{"error": err.Error()} } }()
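
The API and CLI changes above thread the new `session_id` alongside the existing `context`: each response carries the server's current session id, and the CLI stashes it (together with the returned context) in the command context so the next prompt can send both back. Below is a minimal sketch of that round trip as a standalone client, not the repo's own `api` package; the `/api/generate` path, the default `127.0.0.1:11434` address, the newline-delimited JSON streaming, and the model name are assumptions for illustration, not taken from the diff.

```go
// Hypothetical standalone client illustrating how session_id and context are
// threaded between consecutive requests. Field names mirror api.GenerateRequest
// and api.GenerateResponse from this patch; endpoint, address, and streaming
// format are assumptions.
package main

import (
	"bufio"
	"bytes"
	"encoding/json"
	"fmt"
	"log"
	"net/http"
)

type generateRequest struct {
	SessionID int64  `json:"session_id"`
	Model     string `json:"model"`
	Prompt    string `json:"prompt"`
	Context   []int  `json:"context,omitempty"`
}

type generateResponse struct {
	SessionID int64  `json:"session_id"`
	Response  string `json:"response,omitempty"`
	Context   []int  `json:"context,omitempty"`
	Done      bool   `json:"done"`
}

// generate posts a request and returns the last streamed response, which is
// the one that carries the final context and the session id to reuse.
func generate(req generateRequest) (generateResponse, error) {
	var last generateResponse

	body, err := json.Marshal(req)
	if err != nil {
		return last, err
	}

	resp, err := http.Post("http://127.0.0.1:11434/api/generate", "application/json", bytes.NewReader(body))
	if err != nil {
		return last, err
	}
	defer resp.Body.Close()

	scanner := bufio.NewScanner(resp.Body) // assumes one JSON object per line
	for scanner.Scan() {
		if err := json.Unmarshal(scanner.Bytes(), &last); err != nil {
			return last, err
		}
		fmt.Print(last.Response)
	}

	return last, scanner.Err()
}

func main() {
	first, err := generate(generateRequest{Model: "llama2", Prompt: "Hello!"})
	if err != nil {
		log.Fatal(err)
	}

	// Reusing SessionID keeps the server's loaded model; reusing Context keeps
	// the conversation state for the follow-up prompt.
	if _, err := generate(generateRequest{
		SessionID: first.SessionID,
		Model:     "llama2",
		Prompt:    "What did I just say?",
		Context:   first.Context,
	}); err != nil {
		log.Fatal(err)
	}
}
```

A `SessionID` of zero, which is the CLI's fallback when nothing has been stored yet, simply tells the server to start a fresh session.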
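
On the server side (server/routes.go), the handler now keeps a single active session guarded by a package-level mutex: a request whose `session_id` is zero or stale closes the previously loaded model, merges the options, loads a new `llama.LLM`, and stamps the session with `time.Now().UnixNano()`. The following is a standalone distillation of that reuse rule under hypothetical names (`ensureSession`, `loadModel`); the real handler inlines this logic and also performs the option merging shown in the diff.

```go
// Standalone restatement of the session-reuse rule applied by GenerateHandler.
// ensureSession and loadModel are hypothetical names used only for this sketch.
package session

import (
	"sync"
	"time"

	"github.com/jmorganca/ollama/llama"
)

var (
	mu     sync.Mutex
	active struct {
		ID int64
		*llama.LLM
	}
)

// ensureSession returns the currently loaded model when the caller's session id
// matches; otherwise it closes the old model (if any) and loads a fresh one.
func ensureSession(sessionID int64, loadModel func() (*llama.LLM, error)) (int64, *llama.LLM, error) {
	mu.Lock()
	defer mu.Unlock()

	if sessionID != 0 && sessionID == active.ID {
		return active.ID, active.LLM, nil
	}

	if active.LLM != nil {
		active.Close()
		active.LLM = nil
	}

	llm, err := loadModel()
	if err != nil {
		return 0, nil, err
	}

	active.ID = time.Now().UnixNano()
	active.LLM = llm

	return active.ID, llm, nil
}
```

Because only one session is held at a time, a mismatched or missing `session_id` pays the full model-load cost again, and the mutex serializes concurrent generate requests against that single loaded model.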