diff --git a/api/client.go b/api/client.go index 9913e9a3..4453a90a 100644 --- a/api/client.go +++ b/api/client.go @@ -210,3 +210,16 @@ func (c *Client) List(ctx context.Context) (*ListResponse, error) { } return &lr, nil } + +type DeleteProgressFunc func(ProgressResponse) error + +func (c *Client) Delete(ctx context.Context, req *DeleteRequest, fn DeleteProgressFunc) error { + return c.stream(ctx, http.MethodDelete, "/api/delete", req, func(bts []byte) error { + var resp ProgressResponse + if err := json.Unmarshal(bts, &resp); err != nil { + return err + } + + return fn(resp) + }) +} diff --git a/api/types.go b/api/types.go index b14d6811..11c9b8f9 100644 --- a/api/types.go +++ b/api/types.go @@ -37,6 +37,10 @@ type CreateProgress struct { Status string `json:"status"` } +type DeleteRequest struct { + Name string `json:"name"` +} + type PullRequest struct { Name string `json:"name"` Username string `json:"username"` @@ -44,10 +48,10 @@ type PullRequest struct { } type ProgressResponse struct { - Status string `json:"status"` - Digest string `json:"digest,omitempty"` - Total int `json:"total,omitempty"` - Completed int `json:"completed,omitempty"` + Status string `json:"status"` + Digest string `json:"digest,omitempty"` + Total int `json:"total,omitempty"` + Completed int `json:"completed,omitempty"` } type PushRequest struct { diff --git a/cmd/cmd.go b/cmd/cmd.go index 09ac2618..4ce698a0 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -25,7 +25,7 @@ import ( "github.com/jmorganca/ollama/server" ) -func create(cmd *cobra.Command, args []string) error { +func CreateHandler(cmd *cobra.Command, args []string) error { filename, _ := cmd.Flags().GetString("file") filename, err := filepath.Abs(filename) if err != nil { @@ -59,7 +59,7 @@ func create(cmd *cobra.Command, args []string) error { return nil } -func RunRun(cmd *cobra.Command, args []string) error { +func RunHandler(cmd *cobra.Command, args []string) error { mp := server.ParseModelPath(args[0]) fp, err := mp.GetManifestPath(false) if err != nil { @@ -86,7 +86,7 @@ func RunRun(cmd *cobra.Command, args []string) error { return RunGenerate(cmd, args) } -func push(cmd *cobra.Command, args []string) error { +func PushHandler(cmd *cobra.Command, args []string) error { client := api.NewClient() request := api.PushRequest{Name: args[0]} @@ -101,7 +101,7 @@ func push(cmd *cobra.Command, args []string) error { return nil } -func list(cmd *cobra.Command, args []string) error { +func ListHandler(cmd *cobra.Command, args []string) error { client := api.NewClient() models, err := client.List(context.Background()) @@ -131,7 +131,22 @@ func list(cmd *cobra.Command, args []string) error { return nil } -func RunPull(cmd *cobra.Command, args []string) error { +func DeleteHandler(cmd *cobra.Command, args []string) error { + client := api.NewClient() + + request := api.DeleteRequest{Name: args[0]} + fn := func(resp api.ProgressResponse) error { + fmt.Println(resp.Status) + return nil + } + + if err := client.Delete(context.Background(), &request, fn); err != nil { + return err + } + return nil +} + +func PullHandler(cmd *cobra.Command, args []string) error { return pull(args[0]) } @@ -290,7 +305,7 @@ func generateInteractive(cmd *cobra.Command, model string) error { switch { case strings.HasPrefix(line, "/list"): args := strings.Fields(line) - if err := list(cmd, args[1:]); err != nil { + if err := ListHandler(cmd, args[1:]); err != nil { return err } @@ -387,7 +402,7 @@ func NewCLI() *cobra.Command { Use: "create MODEL", Short: "Create a model from a Modelfile", Args: cobra.MinimumNArgs(1), - RunE: create, + RunE: CreateHandler, } createCmd.Flags().StringP("file", "f", "Modelfile", "Name of the Modelfile (default \"Modelfile\")") @@ -396,7 +411,7 @@ func NewCLI() *cobra.Command { Use: "run MODEL [PROMPT]", Short: "Run a model", Args: cobra.MinimumNArgs(1), - RunE: RunRun, + RunE: RunHandler, } runCmd.Flags().Bool("verbose", false, "Show timings for response") @@ -412,21 +427,28 @@ func NewCLI() *cobra.Command { Use: "pull MODEL", Short: "Pull a model from a registry", Args: cobra.MinimumNArgs(1), - RunE: RunPull, + RunE: PullHandler, } pushCmd := &cobra.Command{ Use: "push MODEL", Short: "Push a model to a registry", Args: cobra.MinimumNArgs(1), - RunE: push, + RunE: PushHandler, } listCmd := &cobra.Command{ Use: "list", Aliases: []string{"ls"}, Short: "List models", - RunE: list, + RunE: ListHandler, + } + + deleteCmd := &cobra.Command{ + Use: "rm", + Short: "Remove a model", + Args: cobra.MinimumNArgs(1), + RunE: DeleteHandler, } rootCmd.AddCommand( @@ -436,6 +458,7 @@ func NewCLI() *cobra.Command { pullCmd, pushCmd, listCmd, + deleteCmd, ) return rootCmd diff --git a/server/images.go b/server/images.go index bec89f6d..c635d939 100644 --- a/server/images.go +++ b/server/images.go @@ -487,6 +487,83 @@ func CreateLayer(f io.ReadSeeker) (*LayerReader, error) { return layer, nil } +func DeleteModel(name string, fn func(api.ProgressResponse)) error { + mp := ParseModelPath(name) + + manifest, err := GetManifest(mp) + if err != nil { + fn(api.ProgressResponse{Status: "couldn't retrieve manifest"}) + return err + } + deleteMap := make(map[string]bool) + for _, layer := range manifest.Layers { + deleteMap[layer.Digest] = true + } + deleteMap[manifest.Config.Digest] = true + + fp, err := GetManifestPath() + if err != nil { + fn(api.ProgressResponse{Status: "problem getting manifest path"}) + return err + } + err = filepath.Walk(fp, func(path string, info os.FileInfo, err error) error { + if err != nil { + fn(api.ProgressResponse{Status: "problem walking manifest dir"}) + return err + } + if !info.IsDir() { + path := path[len(fp)+1:] + slashIndex := strings.LastIndex(path, "/") + if slashIndex == -1 { + return nil + } + tag := path[:slashIndex] + ":" + path[slashIndex+1:] + fmp := ParseModelPath(tag) + + // skip the manifest we're trying to delete + if mp.GetFullTagname() == fmp.GetFullTagname() { + return nil + } + + // save (i.e. delete from the deleteMap) any files used in other manifests + manifest, err := GetManifest(fmp) + if err != nil { + log.Printf("skipping file: %s", fp) + return nil + } + for _, layer := range manifest.Layers { + delete(deleteMap, layer.Digest) + } + delete(deleteMap, manifest.Config.Digest) + } + return nil + }) + + // only delete the files which are still in the deleteMap + for k, v := range deleteMap { + if v { + err := os.Remove(k) + if err != nil { + log.Printf("couldn't remove file '%s': %v", k, err) + continue + } + } + } + + fp, err = mp.GetManifestPath(false) + if err != nil { + return err + } + err = os.Remove(fp) + if err != nil { + log.Printf("couldn't remove manifest file '%s': %v", fp, err) + return err + } + fn(api.ProgressResponse{Status: fmt.Sprintf("deleted '%s'", name)}) + + return nil +} + func PushModel(name, username, password string, fn func(api.ProgressResponse)) error { mp := ParseModelPath(name) diff --git a/server/routes.go b/server/routes.go index 4f7becb7..5b6846b0 100644 --- a/server/routes.go +++ b/server/routes.go @@ -18,7 +18,7 @@ import ( "github.com/jmorganca/ollama/llama" ) -func generate(c *gin.Context) { +func GenerateHandler(c *gin.Context) { start := time.Now() var req api.GenerateRequest @@ -78,7 +78,7 @@ func generate(c *gin.Context) { streamResponse(c, ch) } -func pull(c *gin.Context) { +func PullModelHandler(c *gin.Context) { var req api.PullRequest if err := c.ShouldBindJSON(&req); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) @@ -100,7 +100,7 @@ func pull(c *gin.Context) { streamResponse(c, ch) } -func push(c *gin.Context) { +func PushModelHandler(c *gin.Context) { var req api.PushRequest if err := c.ShouldBindJSON(&req); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) @@ -122,7 +122,7 @@ func push(c *gin.Context) { streamResponse(c, ch) } -func create(c *gin.Context) { +func CreateModelHandler(c *gin.Context) { var req api.CreateRequest if err := c.ShouldBindJSON(&req); err != nil { c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()}) @@ -146,7 +146,30 @@ func create(c *gin.Context) { streamResponse(c, ch) } -func list(c *gin.Context) { +func DeleteModelHandler(c *gin.Context) { + var req api.DeleteRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + ch := make(chan any) + go func() { + defer close(ch) + fn := func(r api.ProgressResponse) { + ch <- r + } + + if err := DeleteModel(req.Name, fn); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + }() + + streamResponse(c, ch) +} + +func ListModelsHandler(c *gin.Context) { var models []api.ListResponseModel fp, err := GetManifestPath() if err != nil { @@ -199,11 +222,12 @@ func Serve(ln net.Listener) error { c.String(http.StatusOK, "Ollama is running") }) - r.POST("/api/pull", pull) - r.POST("/api/generate", generate) - r.POST("/api/create", create) - r.POST("/api/push", push) - r.GET("/api/tags", list) + r.POST("/api/pull", PullModelHandler) + r.POST("/api/generate", GenerateHandler) + r.POST("/api/create", CreateModelHandler) + r.POST("/api/push", PushModelHandler) + r.GET("/api/tags", ListModelsHandler) + r.DELETE("/api/delete", DeleteModelHandler) log.Printf("Listening on %s", ln.Addr()) s := &http.Server{