From 1552cee59f6080fc8b74e81317e94381d2e1844a Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Tue, 14 Nov 2023 14:07:40 -0800 Subject: [PATCH] client create modelfile --- api/client.go | 27 +++++++++++++++++- api/types.go | 4 +++ cmd/cmd.go | 73 +++++++++++++++++++++++++++++++++++++++--------- server/routes.go | 57 +++++++++++++++++++++++++++++++++++++ 4 files changed, 147 insertions(+), 14 deletions(-) diff --git a/api/client.go b/api/client.go index 974c08eb..262918b3 100644 --- a/api/client.go +++ b/api/client.go @@ -5,6 +5,7 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "io" "net" @@ -95,11 +96,19 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData var reqBody io.Reader var data []byte var err error - if reqData != nil { + + switch reqData := reqData.(type) { + case io.Reader: + // reqData is already an io.Reader + reqBody = reqData + case nil: + // noop + default: data, err = json.Marshal(reqData) if err != nil { return err } + reqBody = bytes.NewReader(data) } @@ -287,3 +296,19 @@ func (c *Client) Heartbeat(ctx context.Context) error { } return nil } + +func (c *Client) CreateBlob(ctx context.Context, digest string, r io.Reader) (string, error) { + var response CreateBlobResponse + if err := c.do(ctx, http.MethodGet, fmt.Sprintf("/api/blobs/%s/path", digest), nil, &response); err != nil { + var statusError StatusError + if !errors.As(err, &statusError) || statusError.StatusCode != http.StatusNotFound { + return "", err + } + + if err := c.do(ctx, http.MethodPost, fmt.Sprintf("/api/blobs/%s", digest), r, &response); err != nil { + return "", err + } + } + + return response.Path, nil +} diff --git a/api/types.go b/api/types.go index 2a36a1f6..347c4f84 100644 --- a/api/types.go +++ b/api/types.go @@ -105,6 +105,10 @@ type CreateRequest struct { Stream *bool `json:"stream,omitempty"` } +type CreateBlobResponse struct { + Path string `json:"path"` +} + type DeleteRequest struct { Name string `json:"name"` } diff --git a/cmd/cmd.go b/cmd/cmd.go index 8fc6e4c4..30c6bcf6 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -1,9 +1,11 @@ package cmd import ( + "bytes" "context" "crypto/ed25519" "crypto/rand" + "crypto/sha256" "encoding/pem" "errors" "fmt" @@ -27,6 +29,7 @@ import ( "github.com/jmorganca/ollama/api" "github.com/jmorganca/ollama/format" + "github.com/jmorganca/ollama/parser" "github.com/jmorganca/ollama/progressbar" "github.com/jmorganca/ollama/readline" "github.com/jmorganca/ollama/server" @@ -45,17 +48,65 @@ func CreateHandler(cmd *cobra.Command, args []string) error { return err } - var spinner *Spinner + modelfile, err := os.ReadFile(filename) + if err != nil { + return err + } + + spinner := NewSpinner("transferring context") + go spinner.Spin(100 * time.Millisecond) + + commands, err := parser.Parse(bytes.NewReader(modelfile)) + if err != nil { + return err + } + + home, err := os.UserHomeDir() + if err != nil { + return err + } + + for _, c := range commands { + switch c.Name { + case "model", "adapter": + path := c.Args + if path == "~" { + path = home + } else if strings.HasPrefix(path, "~/") { + path = filepath.Join(home, path[2:]) + } + + bin, err := os.Open(path) + if errors.Is(err, os.ErrNotExist) && c.Name == "model" { + // value might be a model reference and not a real file + } else if err != nil { + return err + } + defer bin.Close() + + hash := sha256.New() + if _, err := io.Copy(hash, bin); err != nil { + return err + } + bin.Seek(0, io.SeekStart) + + digest := fmt.Sprintf("sha256:%x", hash.Sum(nil)) + path, err = client.CreateBlob(cmd.Context(), digest, bin) + if err != nil { + return err + } + + modelfile = bytes.ReplaceAll(modelfile, []byte(c.Args), []byte(path)) + } + } var currentDigest string var bar *progressbar.ProgressBar - request := api.CreateRequest{Name: args[0], Path: filename} + request := api.CreateRequest{Name: args[0], Path: filename, Modelfile: string(modelfile)} fn := func(resp api.ProgressResponse) error { if resp.Digest != currentDigest && resp.Digest != "" { - if spinner != nil { - spinner.Stop() - } + spinner.Stop() currentDigest = resp.Digest // pulling bar = progressbar.DefaultBytes( @@ -67,9 +118,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error { bar.Set64(resp.Completed) } else { currentDigest = "" - if spinner != nil { - spinner.Stop() - } + spinner.Stop() spinner = NewSpinner(resp.Status) go spinner.Spin(100 * time.Millisecond) } @@ -81,11 +130,9 @@ func CreateHandler(cmd *cobra.Command, args []string) error { return err } - if spinner != nil { - spinner.Stop() - if spinner.description != "success" { - return errors.New("unexpected end to create model") - } + spinner.Stop() + if spinner.description != "success" { + return errors.New("unexpected end to create model") } return nil diff --git a/server/routes.go b/server/routes.go index 65a96911..c12a7cda 100644 --- a/server/routes.go +++ b/server/routes.go @@ -2,6 +2,7 @@ package server import ( "context" + "crypto/sha256" "encoding/json" "errors" "fmt" @@ -649,6 +650,60 @@ func CopyModelHandler(c *gin.Context) { } } +func GetBlobHandler(c *gin.Context) { + path, err := GetBlobsPath(c.Param("digest")) + if err != nil { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + if _, err := os.Stat(path); err != nil { + c.AbortWithStatusJSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("blob %q not found", c.Param("digest"))}) + return + } + + c.JSON(http.StatusOK, api.CreateBlobResponse{Path: path}) +} + +func CreateBlobHandler(c *gin.Context) { + hash := sha256.New() + temp, err := os.CreateTemp("", c.Param("digest")) + if err != nil { + c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + defer temp.Close() + defer os.Remove(temp.Name()) + + if _, err := io.Copy(temp, io.TeeReader(c.Request.Body, hash)); err != nil { + c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + if fmt.Sprintf("sha256:%x", hash.Sum(nil)) != c.Param("digest") { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "digest does not match body"}) + return + } + + if err := temp.Close(); err != nil { + c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + targetPath, err := GetBlobsPath(c.Param("digest")) + if err != nil { + c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + if err := os.Rename(temp.Name(), targetPath); err != nil { + c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, api.CreateBlobResponse{Path: targetPath}) +} + var defaultAllowOrigins = []string{ "localhost", "127.0.0.1", @@ -708,6 +763,7 @@ func Serve(ln net.Listener, allowOrigins []string) error { r.POST("/api/copy", CopyModelHandler) r.DELETE("/api/delete", DeleteModelHandler) r.POST("/api/show", ShowModelHandler) + r.POST("/api/blobs/:digest", CreateBlobHandler) for _, method := range []string{http.MethodGet, http.MethodHead} { r.Handle(method, "/", func(c *gin.Context) { @@ -715,6 +771,7 @@ func Serve(ln net.Listener, allowOrigins []string) error { }) r.Handle(method, "/api/tags", ListModelsHandler) + r.Handle(method, "/api/blobs/:digest/path", GetBlobHandler) } log.Printf("Listening on %s (version %s)", ln.Addr(), version.Version)