Handle an unreachable model directory during pull; resolve models from the cache

Summary of what the hunks below do:

  api/client.go
    * The Pull stream callback returns nil without invoking fn when the
      decoded progress value has a non-empty Error.Message.

  api/types.go
    * PullProgress gains an Error field serialized under the JSON key "error".

  server/routes.go
    * Adds cacheDir(), which returns the ".ollama" directory under the
      user's home directory and panics if os.UserHomeDir fails.
    * generate: when os.Stat(req.Model) reports os.ErrNotExist, req.Model is
      rewritten to <cacheDir>/models/<model>.bin; any other Stat error is
      answered with HTTP 400 carrying the error text.
    * Serve (pull goroutine): when pull returns an error that matches
      *net.OpError, the handler answers HTTP 502 with a PullProgress whose
      Error is {Code: 502, Message: "failed to get models from directory"};
      all other errors still produce HTTP 400.

Review notes:

  * This patch was rebuilt from a copy whose whitespace had been flattened.
    Line breaks and blank lines were derived from the hunk line counts;
    indentation follows gofmt conventions and inferred nesting depth.
    Verify the whitespace against the pre-image files before applying.
  * The removed/added pair for "modelOpts.NGPULayers" differs only in
    whitespace. The spacing on the removed line is a reconstruction
    (assumed two spaces before the comment) -- TODO confirm.
  * Filesystem paths are built with filepath.Join rather than path.Join
    (path is for slash-separated paths). This adds the "path/filepath"
    import; the server/routes.go hunk headers account for the extra line.
    The post-image blob id on that file's index line predates this change.
  * NOTE(review): cacheDir panics, and generate calls it while serving a
    request. Consider returning an error instead.
  * NOTE(review): req.Model is joined into a filesystem path as-is;
    presumably it is client-supplied -- confirm whether it needs validation.

diff --git a/api/client.go b/api/client.go
index f153f32e..65d36ecb 100644
--- a/api/client.go
+++ b/api/client.go
@@ -106,6 +106,11 @@ func (c *Client) Pull(ctx context.Context, req *PullRequest, fn PullProgressFunc
 				return err
 			}
 
+			if resp.Error.Message != "" {
+				// couldn't pull the model from the directory, proceed anyway
+				return nil
+			}
+
 			return fn(resp)
 		}),
 	)
diff --git a/api/types.go b/api/types.go
index 84dc3731..79bc2c24 100644
--- a/api/types.go
+++ b/api/types.go
@@ -26,6 +26,7 @@ type PullProgress struct {
 	Total     int64   `json:"total"`
 	Completed int64   `json:"completed"`
 	Percent   float64 `json:"percent"`
+	Error     Error   `json:"error"`
 }
 
 type GenerateRequest struct {
diff --git a/server/routes.go b/server/routes.go
index 21a23785..684bfcf7 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -3,12 +3,15 @@ package server
 import (
 	"embed"
 	"encoding/json"
+	"errors"
 	"fmt"
 	"io"
 	"log"
 	"math"
 	"net"
 	"net/http"
+	"os"
 	"path"
+	"path/filepath"
 	"runtime"
 	"strings"
@@ -25,6 +28,15 @@ import (
 var templatesFS embed.FS
 var templates = template.Must(template.ParseFS(templatesFS, "templates/*.prompt"))
 
+func cacheDir() string {
+	home, err := os.UserHomeDir()
+	if err != nil {
+		panic(err)
+	}
+
+	return filepath.Join(home, ".ollama")
+}
+
 func generate(c *gin.Context) {
 	var req api.GenerateRequest
 	req.ModelOptions = api.DefaultModelOptions
@@ -37,9 +49,16 @@ func generate(c *gin.Context) {
 	if remoteModel, _ := getRemote(req.Model); remoteModel != nil {
 		req.Model = remoteModel.FullName()
 	}
+	if _, err := os.Stat(req.Model); err != nil {
+		if !errors.Is(err, os.ErrNotExist) {
+			c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
+			return
+		}
+		req.Model = filepath.Join(cacheDir(), "models", req.Model+".bin")
+	}
 
 	modelOpts := getModelOpts(req)
-	modelOpts.NGPULayers = 1  // hard-code this for now
+	modelOpts.NGPULayers = 1 // hard-code this for now
 
 	model, err := llama.New(req.Model, modelOpts)
 	if err != nil {
@@ -118,6 +137,17 @@ func Serve(ln net.Listener) error {
 		go func() {
 			defer close(progressCh)
 			if err := pull(req.Model, progressCh); err != nil {
+				var opError *net.OpError
+				if errors.As(err, &opError) {
+					result := api.PullProgress{
+						Error: api.Error{
+							Code:    http.StatusBadGateway,
+							Message: "failed to get models from directory",
+						},
+					}
+					c.JSON(http.StatusBadGateway, result)
+					return
+				}
 				c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
 				return
 			}