diff --git a/api/client.go b/api/client.go index d339da56..57162571 100644 --- a/api/client.go +++ b/api/client.go @@ -210,6 +210,13 @@ func (c *Client) List(ctx context.Context) (*ListResponse, error) { return &lr, nil } +func (c *Client) Copy(ctx context.Context, req *CopyRequest) error { + if err := c.do(ctx, http.MethodPost, "/api/copy", req, nil); err != nil { + return err + } + return nil +} + func (c *Client) Delete(ctx context.Context, req *DeleteRequest) error { if err := c.do(ctx, http.MethodDelete, "/api/delete", req, nil); err != nil { return err diff --git a/api/types.go b/api/types.go index e9e72d83..cabec90a 100644 --- a/api/types.go +++ b/api/types.go @@ -48,6 +48,11 @@ type DeleteRequest struct { Name string `json:"name"` } +type CopyRequest struct { + Source string `json:"source"` + Destination string `json:"destination"` +} + type PullRequest struct { Name string `json:"name"` Insecure bool `json:"insecure,omitempty"` diff --git a/cmd/cmd.go b/cmd/cmd.go index f94e3964..693d0cab 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -155,14 +155,25 @@ func ListHandler(cmd *cobra.Command, args []string) error { func DeleteHandler(cmd *cobra.Command, args []string) error { client := api.NewClient() - request := api.DeleteRequest{Name: args[0]} - if err := client.Delete(context.Background(), &request); err != nil { + req := api.DeleteRequest{Name: args[0]} + if err := client.Delete(context.Background(), &req); err != nil { return err } fmt.Printf("deleted '%s'\n", args[0]) return nil } +func CopyHandler(cmd *cobra.Command, args []string) error { + client := api.NewClient() + + req := api.CopyRequest{Source: args[0], Destination: args[1]} + if err := client.Copy(context.Background(), &req); err != nil { + return err + } + fmt.Printf("copied '%s' to '%s'\n", args[0], args[1]) + return nil +} + func PullHandler(cmd *cobra.Command, args []string) error { insecure, err := cmd.Flags().GetBool("insecure") if err != nil { @@ -470,6 +481,13 @@ func NewCLI() *cobra.Command { RunE: ListHandler, } + copyCmd := &cobra.Command{ + Use: "cp", + Short: "Copy a model", + Args: cobra.MinimumNArgs(2), + RunE: CopyHandler, + } + deleteCmd := &cobra.Command{ Use: "rm", Short: "Remove a model", @@ -484,6 +502,7 @@ func NewCLI() *cobra.Command { pullCmd, pushCmd, listCmd, + copyCmd, deleteCmd, ) diff --git a/server/images.go b/server/images.go index c0b3996e..babfca54 100644 --- a/server/images.go +++ b/server/images.go @@ -493,6 +493,32 @@ func CreateLayer(f io.ReadSeeker) (*LayerReader, error) { return layer, nil } +func CopyModel(src, dest string) error { + srcPath, err := ParseModelPath(src).GetManifestPath(false) + if err != nil { + return err + } + destPath, err := ParseModelPath(dest).GetManifestPath(true) + if err != nil { + return err + } + + // copy the file + input, err := ioutil.ReadFile(srcPath) + if err != nil { + fmt.Println("Error reading file:", err) + return err + } + + err = ioutil.WriteFile(destPath, input, 0644) + if err != nil { + fmt.Println("Error reading file:", err) + return err + } + + return nil +} + func DeleteModel(name string) error { mp := ParseModelPath(name) diff --git a/server/routes.go b/server/routes.go index 832c303c..d5b2e127 100644 --- a/server/routes.go +++ b/server/routes.go @@ -228,6 +228,23 @@ func ListModelsHandler(c *gin.Context) { c.JSON(http.StatusOK, api.ListResponse{models}) } +func CopyModelHandler(c *gin.Context) { + var req api.CopyRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + if err := CopyModel(req.Source, req.Destination); err != nil { + if os.IsNotExist(err) { + c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Source)}) + } else { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + } + return + } +} + func Serve(ln net.Listener) error { config := cors.DefaultConfig() config.AllowWildcard = true @@ -254,6 +271,7 @@ func Serve(ln net.Listener) error { r.POST("/api/generate", GenerateHandler) r.POST("/api/create", CreateModelHandler) r.POST("/api/push", PushModelHandler) + r.POST("/api/copy", CopyModelHandler) r.GET("/api/tags", ListModelsHandler) r.DELETE("/api/delete", DeleteModelHandler)