do not reload the running llm when runtime params change (#840)

- only reload the running llm if the model has changed, or the options for loading the running model have changed
- rename loaded llm to runner to differentiate from loaded model image
- remove logic which keeps the first system prompt in the generation context
Bruce MacDonald, 2023-10-19 10:39:58 -04:00, committed by GitHub
parent 235e43d7f6
commit fe6f3b48f7
3 changed files with 66 additions and 86 deletions
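In practice, the rule described above means a request that only changes predict-time options (temperature, top_p, seed, and so on) reuses the running runner, while a change to any load-time option forces a reload. A minimal standalone sketch of that rule, using trimmed stand-ins for the api.Runner and api.Options shapes changed below (field sets abbreviated):

package main

import (
	"fmt"
	"reflect"
)

// Trimmed stand-ins for the api.Runner / api.Options shapes in this commit.
type Runner struct {
	NumCtx int
	NumGPU int
}

type Options struct {
	Runner
	Temperature float32
	Seed        int
}

// needReload mirrors the new check in the server: only the load-time
// subset (Runner) is compared between the loaded and requested options.
func needReload(loaded, next Options) bool {
	return !reflect.DeepEqual(loaded.Runner, next.Runner)
}

func main() {
	a := Options{Runner: Runner{NumCtx: 2048, NumGPU: 1}, Temperature: 0.7}

	b := a
	b.Temperature = 0.2           // predict-time change
	fmt.Println(needReload(a, b)) // false: runner keeps running

	c := a
	c.NumCtx = 4096               // load-time change
	fmt.Println(needReload(a, c)) // true: model must be reloaded
}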

api/types.go

@@ -161,15 +161,10 @@ func (r *GenerateResponse) Summary() {
 	}
 }
 
-type Options struct {
-	Seed int `json:"seed,omitempty"`
-
-	// Backend options
+// Runner options which must be set when the model is loaded into memory
+type Runner struct {
 	UseNUMA bool `json:"numa,omitempty"`
 
-	// Model options
 	NumCtx   int `json:"num_ctx,omitempty"`
-	NumKeep  int `json:"num_keep,omitempty"`
 	NumBatch int `json:"num_batch,omitempty"`
 	NumGQA   int `json:"num_gqa,omitempty"`
 	NumGPU   int `json:"num_gpu,omitempty"`
@@ -183,8 +178,15 @@ type Options struct {
 	EmbeddingOnly      bool    `json:"embedding_only,omitempty"`
 	RopeFrequencyBase  float32 `json:"rope_frequency_base,omitempty"`
 	RopeFrequencyScale float32 `json:"rope_frequency_scale,omitempty"`
+	NumThread          int     `json:"num_thread,omitempty"`
+}
 
-	// Predict options
+type Options struct {
+	Runner
+
+	// Predict options used at runtime
+	NumKeep    int     `json:"num_keep,omitempty"`
+	Seed       int     `json:"seed,omitempty"`
 	NumPredict int     `json:"num_predict,omitempty"`
 	TopK       int     `json:"top_k,omitempty"`
 	TopP       float32 `json:"top_p,omitempty"`
@@ -200,8 +202,6 @@ type Options struct {
 	MirostatEta     float32  `json:"mirostat_eta,omitempty"`
 	PenalizeNewline bool     `json:"penalize_newline,omitempty"`
 	Stop            []string `json:"stop,omitempty"`
-
-	NumThread int `json:"num_thread,omitempty"`
 }
 
 var ErrInvalidOpts = fmt.Errorf("invalid options")
@@ -309,6 +309,7 @@ func DefaultOptions() Options {
 		PenalizeNewline: true,
 		Seed:            -1,
 
+		Runner: Runner{
 			// options set when the model is loaded
 			NumCtx:            2048,
 			RopeFrequencyBase: 10000.0,
@@ -323,6 +324,7 @@ func DefaultOptions() Options {
 			UseMMap:       true,
 			UseNUMA:       false,
 			EmbeddingOnly: true,
+		},
 	}
 }
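Because Runner is embedded in Options without a json tag, its fields are promoted and marshal at the top level, so the request and response wire format is unchanged by this refactor. A minimal sketch using trimmed mirrors of the structs above:

package main

import (
	"encoding/json"
	"fmt"
)

type Runner struct {
	NumCtx int `json:"num_ctx,omitempty"`
}

type Options struct {
	Runner // embedded: its fields are promoted into Options' JSON object

	Seed int `json:"seed,omitempty"`
}

func main() {
	out, err := json.Marshal(Options{Runner: Runner{NumCtx: 2048}, Seed: 42})
	if err != nil {
		panic(err)
	}
	fmt.Println(string(out)) // {"num_ctx":2048,"seed":42} (no nested "Runner" key)
}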

server/images.go

@@ -45,7 +45,6 @@ type Model struct {
 	System       string
 	License      []string
 	Digest       string
-	ConfigDigest string
 	Options      map[string]interface{}
 }
@@ -169,7 +168,6 @@ func GetModel(name string) (*Model, error) {
 		Name:         mp.GetFullTagname(),
 		ShortName:    mp.GetShortTagname(),
 		Digest:       digest,
-		ConfigDigest: manifest.Config.Digest,
 		Template:     "{{ .Prompt }}",
 		License:      []string{},
 	}

server/routes.go

@@ -46,13 +46,13 @@ func init() {
 var loaded struct {
 	mu sync.Mutex
 
-	llm llm.LLM
+	runner llm.LLM
 
 	expireAt    time.Time
 	expireTimer *time.Timer
 
-	digest  string
-	options api.Options
+	*Model
+	*api.Options
 }
 
 var defaultSessionDuration = 5 * time.Minute
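Embedding *Model and *api.Options means fields such as loaded.ModelPath and loaded.Options.EmbeddingOnly are reachable directly on the loaded struct, but a promoted field access panics if the embedded pointer is nil, which is why the needLoad check below tests loaded.runner == nil first and relies on || short-circuiting. A minimal sketch of that promotion (the path value is hypothetical):

package main

import "fmt"

type Model struct {
	ModelPath string
}

var loaded struct {
	*Model
}

func main() {
	loaded.Model = &Model{ModelPath: "/tmp/example-model.bin"} // hypothetical path
	fmt.Println(loaded.ModelPath)                              // promoted field access

	loaded.Model = nil
	// fmt.Println(loaded.ModelPath) // would panic: nil pointer dereference
	ok := loaded.Model != nil && loaded.ModelPath != "" // short-circuit guards the access
	fmt.Println(ok)
}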
@@ -70,59 +70,39 @@ func load(ctx context.Context, workDir string, model *Model, reqOpts map[string]
 	}
 
 	// check if the loaded model is still running in a subprocess, in case something unexpected happened
-	if loaded.llm != nil {
-		if err := loaded.llm.Ping(ctx); err != nil {
+	if loaded.runner != nil {
+		if err := loaded.runner.Ping(ctx); err != nil {
 			log.Print("loaded llm process not responding, closing now")
 			// the subprocess is no longer running, so close it
-			loaded.llm.Close()
-			loaded.llm = nil
-			loaded.digest = ""
+			loaded.runner.Close()
+			loaded.runner = nil
+			loaded.Model = nil
+			loaded.Options = nil
 		}
 	}
 
-	if model.Digest != loaded.digest || !reflect.DeepEqual(loaded.options, opts) {
-		if loaded.llm != nil {
+	needLoad := loaded.runner == nil || // is there a model loaded?
+		loaded.ModelPath != model.ModelPath || // has the base model changed?
+		!reflect.DeepEqual(loaded.AdapterPaths, model.AdapterPaths) || // have the adapters changed?
+		!reflect.DeepEqual(loaded.Options.Runner, opts.Runner) // have the runner options changed?
+
+	if needLoad {
+		if loaded.runner != nil {
 			log.Println("changing loaded model")
-			loaded.llm.Close()
-			loaded.llm = nil
-			loaded.digest = ""
+			loaded.runner.Close()
+			loaded.runner = nil
+			loaded.Model = nil
+			loaded.Options = nil
 		}
 
-		llmModel, err := llm.New(workDir, model.ModelPath, model.AdapterPaths, opts)
+		llmRunner, err := llm.New(workDir, model.ModelPath, model.AdapterPaths, opts)
 		if err != nil {
 			return err
 		}
 
-		// set cache values before modifying opts
-		loaded.llm = llmModel
-		loaded.digest = model.Digest
-		loaded.options = opts
-
-		if opts.NumKeep < 0 {
-			promptWithSystem, err := model.Prompt(api.GenerateRequest{})
-			if err != nil {
-				return err
-			}
-
-			promptNoSystem, err := model.Prompt(api.GenerateRequest{Context: []int{0}})
-			if err != nil {
-				return err
-			}
-
-			tokensWithSystem, err := llmModel.Encode(ctx, promptWithSystem)
-			if err != nil {
-				return err
-			}
-
-			tokensNoSystem, err := llmModel.Encode(ctx, promptNoSystem)
-			if err != nil {
-				return err
-			}
-
-			opts.NumKeep = len(tokensWithSystem) - len(tokensNoSystem)
-
-			llmModel.SetOptions(opts)
-		}
+		loaded.Model = model
+		loaded.runner = llmRunner
+		loaded.Options = &opts
 	}
 
 	loaded.expireAt = time.Now().Add(sessionDuration)
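For reference, the deleted block computed NumKeep (the number of leading tokens to pin across context truncation) as the token-count difference between the prompt rendered with and without the system message. A toy illustration with invented token counts:

package main

import "fmt"

func main() {
	// Invented counts; the deleted code obtained them via model.Prompt + Encode.
	tokensWithSystem := 40 // templated prompt including the system message
	tokensNoSystem := 28   // templated prompt without it
	fmt.Println("num_keep =", tokensWithSystem-tokensNoSystem) // num_keep = 12
}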
@@ -136,13 +116,13 @@ func load(ctx context.Context, workDir string, model *Model, reqOpts map[string]
 			return
 		}
 
-		if loaded.llm == nil {
-			return
+		if loaded.runner != nil {
+			loaded.runner.Close()
 		}
 
-		loaded.llm.Close()
-		loaded.llm = nil
-		loaded.digest = ""
+		loaded.runner = nil
+		loaded.Model = nil
+		loaded.Options = nil
 	})
 }
@@ -215,7 +195,7 @@ func GenerateHandler(c *gin.Context) {
 		if req.Prompt == "" && req.Template == "" && req.System == "" {
 			ch <- api.GenerateResponse{Model: req.Model, Done: true}
 		} else {
-			if err := loaded.llm.Predict(c.Request.Context(), req.Context, prompt, fn); err != nil {
+			if err := loaded.runner.Predict(c.Request.Context(), req.Context, prompt, fn); err != nil {
 				ch <- gin.H{"error": err.Error()}
 			}
 		}
@@ -263,12 +243,12 @@ func EmbeddingHandler(c *gin.Context) {
 		return
 	}
 
-	if !loaded.options.EmbeddingOnly {
+	if !loaded.Options.EmbeddingOnly {
 		c.JSON(http.StatusBadRequest, gin.H{"error": "embedding option must be set to true"})
 		return
 	}
 
-	embedding, err := loaded.llm.Embedding(c.Request.Context(), req.Prompt)
+	embedding, err := loaded.runner.Embedding(c.Request.Context(), req.Prompt)
 	if err != nil {
 		log.Printf("embedding generation failed: %v", err)
 		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
@@ -599,8 +579,8 @@ func Serve(ln net.Listener, allowOrigins []string) error {
 	signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
 	go func() {
 		<-signals
-		if loaded.llm != nil {
-			loaded.llm.Close()
+		if loaded.runner != nil {
+			loaded.runner.Close()
 		}
 		os.RemoveAll(workDir)
 		os.Exit(0)