Fix no slots available error with concurrent requests (#4160)

This commit is contained in:
Jeffrey Morgan 2024-05-06 14:22:53 -07:00 committed by GitHub
parent c9f98622b1
commit ed740a2504
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -338,7 +338,7 @@ type ServerStatus int
const ( // iota is reset to 0 const ( // iota is reset to 0
ServerStatusReady ServerStatus = iota ServerStatusReady ServerStatus = iota
ServerStatusNoSlotsAvaialble ServerStatusNoSlotsAvailable
ServerStatusLoadingModel ServerStatusLoadingModel
ServerStatusNotResponding ServerStatusNotResponding
ServerStatusError ServerStatusError
@ -348,7 +348,7 @@ func (s ServerStatus) ToString() string {
switch s { switch s {
case ServerStatusReady: case ServerStatusReady:
return "llm server ready" return "llm server ready"
case ServerStatusNoSlotsAvaialble: case ServerStatusNoSlotsAvailable:
return "llm busy - no slots available" return "llm busy - no slots available"
case ServerStatusLoadingModel: case ServerStatusLoadingModel:
return "llm server loading model" return "llm server loading model"
@ -405,7 +405,7 @@ func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
case "ok": case "ok":
return ServerStatusReady, nil return ServerStatusReady, nil
case "no slot available": case "no slot available":
return ServerStatusNoSlotsAvaialble, nil return ServerStatusNoSlotsAvailable, nil
case "loading model": case "loading model":
return ServerStatusLoadingModel, nil return ServerStatusLoadingModel, nil
default: default:
@ -413,6 +413,29 @@ func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
} }
} }
// getServerStatusRetry will retry if ServerStatusNoSlotsAvailable is received
func (s *llmServer) getServerStatusRetry(ctx context.Context) (ServerStatus, error) {
var retries int
for {
status, err := s.getServerStatus(ctx)
if err != nil {
return status, err
}
if status == ServerStatusNoSlotsAvailable {
if retries >= 10 {
return status, fmt.Errorf("no slots available after %d retries", retries)
}
time.Sleep(5 * time.Millisecond)
retries++
continue
}
return status, nil
}
}
func (s *llmServer) Ping(ctx context.Context) error { func (s *llmServer) Ping(ctx context.Context) error {
_, err := s.getServerStatus(ctx) _, err := s.getServerStatus(ctx)
if err != nil { if err != nil {
@ -510,7 +533,6 @@ ws ::= ([ \t\n] ws)?
` `
const maxBufferSize = 512 * format.KiloByte const maxBufferSize = 512 * format.KiloByte
const maxRetries = 3
type ImageData struct { type ImageData struct {
Data []byte `json:"data"` Data []byte `json:"data"`
@ -586,7 +608,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
} }
// Make sure the server is ready // Make sure the server is ready
status, err := s.getServerStatus(ctx) status, err := s.getServerStatusRetry(ctx)
if err != nil { if err != nil {
return err return err
} else if status != ServerStatusReady { } else if status != ServerStatusReady {
@ -600,13 +622,6 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
} }
} }
retryDelay := 100 * time.Microsecond
for retries := 0; retries < maxRetries; retries++ {
if retries > 0 {
time.Sleep(retryDelay) // wait before retrying
retryDelay *= 2 // exponential backoff
}
// Handling JSON marshaling with special characters unescaped. // Handling JSON marshaling with special characters unescaped.
buffer := &bytes.Buffer{} buffer := &bytes.Buffer{}
enc := json.NewEncoder(buffer) enc := json.NewEncoder(buffer)
@ -617,20 +632,20 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
} }
endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", s.port) endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", s.port)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, buffer) serverReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, buffer)
if err != nil { if err != nil {
return fmt.Errorf("error creating POST request: %v", err) return fmt.Errorf("error creating POST request: %v", err)
} }
req.Header.Set("Content-Type", "application/json") serverReq.Header.Set("Content-Type", "application/json")
resp, err := http.DefaultClient.Do(req) res, err := http.DefaultClient.Do(serverReq)
if err != nil { if err != nil {
return fmt.Errorf("POST predict: %v", err) return fmt.Errorf("POST predict: %v", err)
} }
defer resp.Body.Close() defer res.Body.Close()
if resp.StatusCode >= 400 { if res.StatusCode >= 400 {
bodyBytes, err := io.ReadAll(resp.Body) bodyBytes, err := io.ReadAll(res.Body)
if err != nil { if err != nil {
return fmt.Errorf("failed reading llm error response: %w", err) return fmt.Errorf("failed reading llm error response: %w", err)
} }
@ -638,11 +653,10 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
return fmt.Errorf("%s", bodyBytes) return fmt.Errorf("%s", bodyBytes)
} }
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(res.Body)
buf := make([]byte, 0, maxBufferSize) buf := make([]byte, 0, maxBufferSize)
scanner.Buffer(buf, maxBufferSize) scanner.Buffer(buf, maxBufferSize)
retryNeeded := false
// keep track of the last token generated, this is used to abort if the model starts looping // keep track of the last token generated, this is used to abort if the model starts looping
var lastToken string var lastToken string
var tokenRepeat int var tokenRepeat int
@ -658,12 +672,6 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
continue continue
} }
// try again on slot unavailable
if bytes.Contains(line, []byte("slot unavailable")) {
retryNeeded = true
break
}
evt, ok := bytes.CutPrefix(line, []byte("data: ")) evt, ok := bytes.CutPrefix(line, []byte("data: "))
if !ok { if !ok {
return fmt.Errorf("error parsing llm response stream: %s", line) return fmt.Errorf("error parsing llm response stream: %s", line)
@ -714,19 +722,13 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
if s.status != nil && s.status.LastErrMsg != "" { if s.status != nil && s.status.LastErrMsg != "" {
msg = s.status.LastErrMsg msg = s.status.LastErrMsg
} }
return fmt.Errorf("an unknown error was encountered while running the model %s", msg) return fmt.Errorf("an unknown error was encountered while running the model %s", msg)
} }
return fmt.Errorf("error reading llm response: %v", err) return fmt.Errorf("error reading llm response: %v", err)
} }
if !retryNeeded { return nil
return nil // success
}
}
// should never reach here ideally
return fmt.Errorf("max retries exceeded")
} }
type EmbeddingRequest struct { type EmbeddingRequest struct {
@ -743,8 +745,9 @@ func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, er
return nil, err return nil, err
} }
defer s.sem.Release(1) defer s.sem.Release(1)
// Make sure the server is ready // Make sure the server is ready
status, err := s.getServerStatus(ctx) status, err := s.getServerStatusRetry(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} else if status != ServerStatusReady { } else if status != ServerStatusReady {
@ -799,7 +802,7 @@ func (s *llmServer) Tokenize(ctx context.Context, content string) ([]int, error)
status, err := s.getServerStatus(ctx) status, err := s.getServerStatus(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} else if status != ServerStatusReady && status != ServerStatusNoSlotsAvaialble { } else if status != ServerStatusReady && status != ServerStatusNoSlotsAvailable {
return nil, fmt.Errorf("unexpected server status: %s", status.ToString()) return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
} }
@ -851,7 +854,7 @@ func (s *llmServer) Detokenize(ctx context.Context, tokens []int) (string, error
status, err := s.getServerStatus(ctx) status, err := s.getServerStatus(ctx)
if err != nil { if err != nil {
return "", err return "", err
} else if status != ServerStatusReady && status != ServerStatusNoSlotsAvaialble { } else if status != ServerStatusReady && status != ServerStatusNoSlotsAvailable {
return "", fmt.Errorf("unexpected server status: %s", status.ToString()) return "", fmt.Errorf("unexpected server status: %s", status.ToString())
} }