From 592dae31c826cde394f721eac64ce5d4748f4ef0 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Tue, 16 Apr 2024 16:22:38 -0700 Subject: [PATCH] update copy to use model.Name --- server/images.go | 27 ++++++++++++--------------- server/routes.go | 39 ++++++++++++++++++++------------------- 2 files changed, 32 insertions(+), 34 deletions(-) diff --git a/server/images.go b/server/images.go index dd44a0f4..7ba5134c 100644 --- a/server/images.go +++ b/server/images.go @@ -29,6 +29,7 @@ import ( "github.com/ollama/ollama/format" "github.com/ollama/ollama/llm" "github.com/ollama/ollama/parser" + "github.com/ollama/ollama/types/model" "github.com/ollama/ollama/version" ) @@ -701,36 +702,32 @@ func convertModel(name, path string, fn func(resp api.ProgressResponse)) (string return path, nil } -func CopyModel(src, dest string) error { - srcModelPath := ParseModelPath(src) - srcPath, err := srcModelPath.GetManifestPath() +func CopyModel(src, dst model.Name) error { + manifests, err := GetManifestPath() if err != nil { return err } - destModelPath := ParseModelPath(dest) - destPath, err := destModelPath.GetManifestPath() - if err != nil { - return err - } - if err := os.MkdirAll(filepath.Dir(destPath), 0o755); err != nil { + dstpath := filepath.Join(manifests, dst.FilepathNoBuild()) + if err := os.MkdirAll(filepath.Dir(dstpath), 0o755); err != nil { return err } - // copy the file - input, err := os.ReadFile(srcPath) + srcpath := filepath.Join(manifests, src.FilepathNoBuild()) + srcfile, err := os.Open(srcpath) if err != nil { - fmt.Println("Error reading file:", err) return err } + defer srcfile.Close() - err = os.WriteFile(destPath, input, 0o644) + dstfile, err := os.Create(dstpath) if err != nil { - fmt.Println("Error reading file:", err) return err } + defer dstfile.Close() - return nil + _, err = io.Copy(dstfile, srcfile) + return err } func deleteUnusedLayers(skipModelPath *ModelPath, deleteMap map[string]struct{}, dryRun bool) error { diff --git a/server/routes.go b/server/routes.go index 016deb34..a5ae3ff4 100644 --- a/server/routes.go +++ b/server/routes.go @@ -29,6 +29,7 @@ import ( "github.com/ollama/ollama/llm" "github.com/ollama/ollama/openai" "github.com/ollama/ollama/parser" + "github.com/ollama/ollama/types/model" "github.com/ollama/ollama/version" ) @@ -788,34 +789,34 @@ func (s *Server) ListModelsHandler(c *gin.Context) { } func (s *Server) CopyModelHandler(c *gin.Context) { - var req api.CopyRequest - err := c.ShouldBindJSON(&req) - switch { - case errors.Is(err, io.EOF): + var r api.CopyRequest + if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) return - case err != nil: + } else if err != nil { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } - if req.Source == "" || req.Destination == "" { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "source add destination are required"}) + src := model.ParseName(r.Source) + if !src.IsValid() { + _ = c.Error(fmt.Errorf("source %q is invalid", r.Source)) + } + + dst := model.ParseName(r.Destination) + if !dst.IsValid() { + _ = c.Error(fmt.Errorf("destination %q is invalid", r.Destination)) + } + + if len(c.Errors) > 0 { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": c.Errors.Errors()}) return } - if err := ParseModelPath(req.Destination).Validate(); err != nil { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - if err := CopyModel(req.Source, req.Destination); err != nil { - if os.IsNotExist(err) { - c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Source)}) - } else { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - } - return + if err := CopyModel(src, dst); errors.Is(err, os.ErrNotExist) { + c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model %q not found", r.Source)}) + } else if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) } }