diff --git a/api/client.go b/api/client.go index 974c08eb..44af222c 100644 --- a/api/client.go +++ b/api/client.go @@ -5,6 +5,7 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "io" "net" @@ -95,11 +96,19 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData var reqBody io.Reader var data []byte var err error - if reqData != nil { + + switch reqData := reqData.(type) { + case io.Reader: + // reqData is already an io.Reader + reqBody = reqData + case nil: + // noop + default: data, err = json.Marshal(reqData) if err != nil { return err } + reqBody = bytes.NewReader(data) } @@ -287,3 +296,18 @@ func (c *Client) Heartbeat(ctx context.Context) error { } return nil } + +func (c *Client) CreateBlob(ctx context.Context, digest string, r io.Reader) error { + if err := c.do(ctx, http.MethodHead, fmt.Sprintf("/api/blobs/%s", digest), nil, nil); err != nil { + var statusError StatusError + if !errors.As(err, &statusError) || statusError.StatusCode != http.StatusNotFound { + return err + } + + if err := c.do(ctx, http.MethodPost, fmt.Sprintf("/api/blobs/%s", digest), r, nil); err != nil { + return err + } + } + + return nil +} diff --git a/api/types.go b/api/types.go index ffa5b7ca..2a36a1f6 100644 --- a/api/types.go +++ b/api/types.go @@ -99,9 +99,10 @@ type EmbeddingResponse struct { } type CreateRequest struct { - Name string `json:"name"` - Path string `json:"path"` - Stream *bool `json:"stream,omitempty"` + Name string `json:"name"` + Path string `json:"path"` + Modelfile string `json:"modelfile"` + Stream *bool `json:"stream,omitempty"` } type DeleteRequest struct { diff --git a/cmd/cmd.go b/cmd/cmd.go index 8fc6e4c4..008c6b38 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -1,9 +1,11 @@ package cmd import ( + "bytes" "context" "crypto/ed25519" "crypto/rand" + "crypto/sha256" "encoding/pem" "errors" "fmt" @@ -27,6 +29,7 @@ import ( "github.com/jmorganca/ollama/api" "github.com/jmorganca/ollama/format" + "github.com/jmorganca/ollama/parser" "github.com/jmorganca/ollama/progressbar" "github.com/jmorganca/ollama/readline" "github.com/jmorganca/ollama/server" @@ -45,17 +48,64 @@ func CreateHandler(cmd *cobra.Command, args []string) error { return err } - var spinner *Spinner + modelfile, err := os.ReadFile(filename) + if err != nil { + return err + } + + spinner := NewSpinner("transferring context") + go spinner.Spin(100 * time.Millisecond) + + commands, err := parser.Parse(bytes.NewReader(modelfile)) + if err != nil { + return err + } + + home, err := os.UserHomeDir() + if err != nil { + return err + } + + for _, c := range commands { + switch c.Name { + case "model", "adapter": + path := c.Args + if path == "~" { + path = home + } else if strings.HasPrefix(path, "~/") { + path = filepath.Join(home, path[2:]) + } + + bin, err := os.Open(path) + if errors.Is(err, os.ErrNotExist) && c.Name == "model" { + continue + } else if err != nil { + return err + } + defer bin.Close() + + hash := sha256.New() + if _, err := io.Copy(hash, bin); err != nil { + return err + } + bin.Seek(0, io.SeekStart) + + digest := fmt.Sprintf("sha256:%x", hash.Sum(nil)) + if err = client.CreateBlob(cmd.Context(), digest, bin); err != nil { + return err + } + + modelfile = bytes.ReplaceAll(modelfile, []byte(c.Args), []byte("@"+digest)) + } + } var currentDigest string var bar *progressbar.ProgressBar - request := api.CreateRequest{Name: args[0], Path: filename} + request := api.CreateRequest{Name: args[0], Path: filename, Modelfile: string(modelfile)} fn := func(resp api.ProgressResponse) error { if resp.Digest != currentDigest && resp.Digest != "" { - if spinner != nil { - spinner.Stop() - } + spinner.Stop() currentDigest = resp.Digest // pulling bar = progressbar.DefaultBytes( @@ -67,9 +117,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error { bar.Set64(resp.Completed) } else { currentDigest = "" - if spinner != nil { - spinner.Stop() - } + spinner.Stop() spinner = NewSpinner(resp.Status) go spinner.Spin(100 * time.Millisecond) } @@ -81,11 +129,9 @@ func CreateHandler(cmd *cobra.Command, args []string) error { return err } - if spinner != nil { - spinner.Stop() - if spinner.description != "success" { - return errors.New("unexpected end to create model") - } + spinner.Stop() + if spinner.description != "success" { + return errors.New("unexpected end to create model") } return nil diff --git a/docs/api.md b/docs/api.md index 08402266..9bb4d378 100644 --- a/docs/api.md +++ b/docs/api.md @@ -292,12 +292,13 @@ curl -X POST http://localhost:11434/api/generate -d '{ POST /api/create ``` -Create a model from a [`Modelfile`](./modelfile.md) +Create a model from a [`Modelfile`](./modelfile.md). It is recommended to set `modelfile` to the content of the Modelfile rather than just set `path`. This is a requirement for remote create. Remote model creation should also create any file blobs, fields such as `FROM` and `ADAPTER`, explicitly with the server using [Create a Blob](#create-a-blob) and the value to the path indicated in the response. ### Parameters - `name`: name of the model to create -- `path`: path to the Modelfile +- `path`: path to the Modelfile (deprecated: please use modelfile instead) +- `modelfile`: contents of the Modelfile - `stream`: (optional) if `false` the response will be returned as a single response object, rather than a stream of objects ### Examples @@ -307,7 +308,8 @@ Create a model from a [`Modelfile`](./modelfile.md) ```shell curl -X POST http://localhost:11434/api/create -d '{ "name": "mario", - "path": "~/Modelfile" + "path": "~/Modelfile", + "modelfile": "FROM llama2" }' ``` @@ -321,6 +323,54 @@ A stream of JSON objects. When finished, `status` is `success`. } ``` +### Check if a Blob Exists + +```shell +HEAD /api/blobs/:digest +``` + +Check if a blob is known to the server. + +#### Query Parameters + +- `digest`: the SHA256 digest of the blob + +#### Examples + +##### Request + +```shell +curl -I http://localhost:11434/api/blobs/sha256:29fdb92e57cf0827ded04ae6461b5931d01fa595843f55d36f5b275a52087dd2 +``` + +##### Response + +Return 200 OK if the blob exists, 404 Not Found if it does not. + +### Create a Blob + +```shell +POST /api/blobs/:digest +``` + +Create a blob from a file. Returns the server file path. + +#### Query Parameters + +- `digest`: the expected SHA256 digest of the file + +#### Examples + +##### Request + +```shell +curl -T model.bin -X POST http://localhost:11434/api/blobs/sha256:29fdb92e57cf0827ded04ae6461b5931d01fa595843f55d36f5b275a52087dd2 +``` + +##### Response + +Return 201 Created if the blob was successfully created. + ## List Local Models ```shell diff --git a/server/images.go b/server/images.go index 8d784fef..d8ff0fd8 100644 --- a/server/images.go +++ b/server/images.go @@ -248,200 +248,181 @@ func filenameWithPath(path, f string) (string, error) { return f, nil } -func CreateModel(ctx context.Context, name string, path string, fn func(resp api.ProgressResponse)) error { - mp := ParseModelPath(name) - - var manifest *ManifestV2 - var err error - var noprune string - - // build deleteMap to prune unused layers - deleteMap := make(map[string]bool) - - if noprune = os.Getenv("OLLAMA_NOPRUNE"); noprune == "" { - manifest, _, err = GetManifest(mp) - if err != nil && !errors.Is(err, os.ErrNotExist) { - return err - } - - if manifest != nil { - for _, l := range manifest.Layers { - deleteMap[l.Digest] = true - } - deleteMap[manifest.Config.Digest] = true - } - } - - mf, err := os.Open(path) +func realpath(p string) string { + abspath, err := filepath.Abs(p) if err != nil { - fn(api.ProgressResponse{Status: fmt.Sprintf("couldn't open modelfile '%s'", path)}) - return fmt.Errorf("failed to open file: %w", err) + return p } - defer mf.Close() - fn(api.ProgressResponse{Status: "parsing modelfile"}) - commands, err := parser.Parse(mf) + home, err := os.UserHomeDir() if err != nil { - return err + return abspath } + if p == "~" { + return home + } else if strings.HasPrefix(p, "~/") { + return filepath.Join(home, p[2:]) + } + + return abspath +} + +func CreateModel(ctx context.Context, name string, commands []parser.Command, fn func(resp api.ProgressResponse)) error { config := ConfigV2{ - Architecture: "amd64", OS: "linux", + Architecture: "amd64", } + deleteMap := make(map[string]struct{}) + var layers []*LayerReader + params := make(map[string][]string) - var sourceParams map[string]any + fromParams := make(map[string]any) + for _, c := range commands { - log.Printf("[%s] - %s\n", c.Name, c.Args) + log.Printf("[%s] - %s", c.Name, c.Args) + mediatype := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name) + switch c.Name { case "model": - fn(api.ProgressResponse{Status: "looking for model"}) + if strings.HasPrefix(c.Args, "@") { + blobPath, err := GetBlobsPath(strings.TrimPrefix(c.Args, "@")) + if err != nil { + return err + } - mp := ParseModelPath(c.Args) - mf, _, err := GetManifest(mp) + c.Args = blobPath + } + + bin, err := os.Open(realpath(c.Args)) if err != nil { - modelFile, err := filenameWithPath(path, c.Args) - if err != nil { + // not a file on disk so must be a model reference + modelpath := ParseModelPath(c.Args) + manifest, _, err := GetManifest(modelpath) + switch { + case errors.Is(err, os.ErrNotExist): + fn(api.ProgressResponse{Status: "pulling model"}) + if err := PullModel(ctx, c.Args, &RegistryOptions{}, fn); err != nil { + return err + } + + manifest, _, err = GetManifest(modelpath) + if err != nil { + return err + } + case err != nil: return err } - if _, err := os.Stat(modelFile); err != nil { - // the model file does not exist, try pulling it - if errors.Is(err, os.ErrNotExist) { - fn(api.ProgressResponse{Status: "pulling model file"}) - if err := PullModel(ctx, c.Args, &RegistryOptions{}, fn); err != nil { - return err - } - mf, _, err = GetManifest(mp) - if err != nil { - return fmt.Errorf("failed to open file after pull: %v", err) - } - } else { - return err - } - } else { - // create a model from this specified file - fn(api.ProgressResponse{Status: "creating model layer"}) - file, err := os.Open(modelFile) - if err != nil { - return fmt.Errorf("failed to open file: %v", err) - } - defer file.Close() - ggml, err := llm.DecodeGGML(file) - if err != nil { - return err - } - - config.ModelFormat = ggml.Name() - config.ModelFamily = ggml.ModelFamily() - config.ModelType = ggml.ModelType() - config.FileType = ggml.FileType() - - // reset the file - file.Seek(0, io.SeekStart) - - l, err := CreateLayer(file) - if err != nil { - return fmt.Errorf("failed to create layer: %v", err) - } - l.MediaType = "application/vnd.ollama.image.model" - layers = append(layers, l) - } - } - - if mf != nil { fn(api.ProgressResponse{Status: "reading model metadata"}) - sourceBlobPath, err := GetBlobsPath(mf.Config.Digest) + fromConfigPath, err := GetBlobsPath(manifest.Config.Digest) if err != nil { return err } - sourceBlob, err := os.Open(sourceBlobPath) + fromConfigFile, err := os.Open(fromConfigPath) if err != nil { return err } - defer sourceBlob.Close() + defer fromConfigFile.Close() - var source ConfigV2 - if err := json.NewDecoder(sourceBlob).Decode(&source); err != nil { + var fromConfig ConfigV2 + if err := json.NewDecoder(fromConfigFile).Decode(&fromConfig); err != nil { return err } - // copy the model metadata - config.ModelFamily = source.ModelFamily - config.ModelType = source.ModelType - config.ModelFormat = source.ModelFormat - config.FileType = source.FileType + config.ModelFormat = fromConfig.ModelFormat + config.ModelFamily = fromConfig.ModelFamily + config.ModelType = fromConfig.ModelType + config.FileType = fromConfig.FileType - for _, l := range mf.Layers { - if l.MediaType == "application/vnd.ollama.image.params" { - sourceParamsBlobPath, err := GetBlobsPath(l.Digest) + for _, layer := range manifest.Layers { + deleteMap[layer.Digest] = struct{}{} + if layer.MediaType == "application/vnd.ollama.image.params" { + fromParamsPath, err := GetBlobsPath(layer.Digest) if err != nil { return err } - sourceParamsBlob, err := os.Open(sourceParamsBlobPath) + fromParamsFile, err := os.Open(fromParamsPath) if err != nil { return err } - defer sourceParamsBlob.Close() + defer fromParamsFile.Close() - if err := json.NewDecoder(sourceParamsBlob).Decode(&sourceParams); err != nil { + if err := json.NewDecoder(fromParamsFile).Decode(&fromParams); err != nil { return err } } - newLayer, err := GetLayerWithBufferFromLayer(l) + layer, err := GetLayerWithBufferFromLayer(layer) if err != nil { return err } - newLayer.From = mp.GetShortTagname() - layers = append(layers, newLayer) - } - } - case "adapter": - fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)}) - fp, err := filenameWithPath(path, c.Args) + layer.From = modelpath.GetShortTagname() + layers = append(layers, layer) + } + + deleteMap[manifest.Config.Digest] = struct{}{} + continue + } + defer bin.Close() + + fn(api.ProgressResponse{Status: "creating model layer"}) + ggml, err := llm.DecodeGGML(bin) if err != nil { return err } - // create a model from this specified file - fn(api.ProgressResponse{Status: "creating model layer"}) + config.ModelFormat = ggml.Name() + config.ModelFamily = ggml.ModelFamily() + config.ModelType = ggml.ModelType() + config.FileType = ggml.FileType() - file, err := os.Open(fp) + bin.Seek(0, io.SeekStart) + layer, err := CreateLayer(bin) if err != nil { - return fmt.Errorf("failed to open file: %v", err) + return err } - defer file.Close() - l, err := CreateLayer(file) + layer.MediaType = mediatype + layers = append(layers, layer) + case "adapter": + fn(api.ProgressResponse{Status: "creating adapter layer"}) + bin, err := os.Open(realpath(c.Args)) if err != nil { - return fmt.Errorf("failed to create layer: %v", err) + return err } - l.MediaType = "application/vnd.ollama.image.adapter" - layers = append(layers, l) - case "license": - fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)}) - mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name) + defer bin.Close() - layer, err := CreateLayer(strings.NewReader(c.Args)) + layer, err := CreateLayer(bin) if err != nil { return err } if layer.Size > 0 { - layer.MediaType = mediaType + layer.MediaType = mediatype layers = append(layers, layer) } - case "template", "system", "prompt": - fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)}) - // remove the layer if one exists - mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name) - layers = removeLayerFromLayers(layers, mediaType) + case "license": + fn(api.ProgressResponse{Status: "creating license layer"}) + layer, err := CreateLayer(strings.NewReader(c.Args)) + if err != nil { + return err + } + + if layer.Size > 0 { + layer.MediaType = mediatype + layers = append(layers, layer) + } + case "template", "system": + fn(api.ProgressResponse{Status: fmt.Sprintf("creating %s layer", c.Name)}) + + // remove duplicate layers + layers = removeLayerFromLayers(layers, mediatype) layer, err := CreateLayer(strings.NewReader(c.Args)) if err != nil { @@ -449,48 +430,47 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api } if layer.Size > 0 { - layer.MediaType = mediaType + layer.MediaType = mediatype layers = append(layers, layer) } default: - // runtime parameters, build a list of args for each parameter to allow multiple values to be specified (ex: multiple stop sequences) params[c.Name] = append(params[c.Name], c.Args) } } - // Create a single layer for the parameters if len(params) > 0 { - fn(api.ProgressResponse{Status: "creating parameter layer"}) + fn(api.ProgressResponse{Status: "creating parameters layer"}) - layers = removeLayerFromLayers(layers, "application/vnd.ollama.image.params") formattedParams, err := formatParams(params) if err != nil { - return fmt.Errorf("couldn't create params json: %v", err) + return err } - for k, v := range sourceParams { + for k, v := range fromParams { if _, ok := formattedParams[k]; !ok { formattedParams[k] = v } } if config.ModelType == "65B" { - if numGQA, ok := formattedParams["num_gqa"].(int); ok && numGQA == 8 { + if gqa, ok := formattedParams["gqa"].(int); ok && gqa == 8 { config.ModelType = "70B" } } - bts, err := json.Marshal(formattedParams) + var b bytes.Buffer + if err := json.NewEncoder(&b).Encode(formattedParams); err != nil { + return err + } + + fn(api.ProgressResponse{Status: "creating config layer"}) + layer, err := CreateLayer(bytes.NewReader(b.Bytes())) if err != nil { return err } - l, err := CreateLayer(bytes.NewReader(bts)) - if err != nil { - return fmt.Errorf("failed to create layer: %v", err) - } - l.MediaType = "application/vnd.ollama.image.params" - layers = append(layers, l) + layer.MediaType = "application/vnd.ollama.image.params" + layers = append(layers, layer) } digests, err := getLayerDigests(layers) @@ -498,36 +478,31 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api return err } - var manifestLayers []*Layer - for _, l := range layers { - manifestLayers = append(manifestLayers, &l.Layer) - delete(deleteMap, l.Layer.Digest) - } - - // Create a layer for the config object - fn(api.ProgressResponse{Status: "creating config layer"}) - cfg, err := createConfigLayer(config, digests) + configLayer, err := createConfigLayer(config, digests) if err != nil { return err } - layers = append(layers, cfg) - delete(deleteMap, cfg.Layer.Digest) + + layers = append(layers, configLayer) + delete(deleteMap, configLayer.Digest) if err := SaveLayers(layers, fn, false); err != nil { return err } - // Create the manifest + var contentLayers []*Layer + for _, layer := range layers { + contentLayers = append(contentLayers, &layer.Layer) + delete(deleteMap, layer.Digest) + } + fn(api.ProgressResponse{Status: "writing manifest"}) - err = CreateManifest(name, cfg, manifestLayers) - if err != nil { + if err := CreateManifest(name, configLayer, contentLayers); err != nil { return err } - if noprune == "" { - fn(api.ProgressResponse{Status: "removing any unused layers"}) - err = deleteUnusedLayers(nil, deleteMap, false) - if err != nil { + if noprune := os.Getenv("OLLAMA_NOPRUNE"); noprune == "" { + if err := deleteUnusedLayers(nil, deleteMap, false); err != nil { return err } } @@ -739,7 +714,7 @@ func CopyModel(src, dest string) error { return nil } -func deleteUnusedLayers(skipModelPath *ModelPath, deleteMap map[string]bool, dryRun bool) error { +func deleteUnusedLayers(skipModelPath *ModelPath, deleteMap map[string]struct{}, dryRun bool) error { fp, err := GetManifestPath() if err != nil { return err @@ -779,21 +754,19 @@ func deleteUnusedLayers(skipModelPath *ModelPath, deleteMap map[string]bool, dry } // only delete the files which are still in the deleteMap - for k, v := range deleteMap { - if v { - fp, err := GetBlobsPath(k) - if err != nil { - log.Printf("couldn't get file path for '%s': %v", k, err) + for k := range deleteMap { + fp, err := GetBlobsPath(k) + if err != nil { + log.Printf("couldn't get file path for '%s': %v", k, err) + continue + } + if !dryRun { + if err := os.Remove(fp); err != nil { + log.Printf("couldn't remove file '%s': %v", fp, err) continue } - if !dryRun { - if err := os.Remove(fp); err != nil { - log.Printf("couldn't remove file '%s': %v", fp, err) - continue - } - } else { - log.Printf("wanted to remove: %s", fp) - } + } else { + log.Printf("wanted to remove: %s", fp) } } @@ -801,7 +774,7 @@ func deleteUnusedLayers(skipModelPath *ModelPath, deleteMap map[string]bool, dry } func PruneLayers() error { - deleteMap := make(map[string]bool) + deleteMap := make(map[string]struct{}) p, err := GetBlobsPath("") if err != nil { return err @@ -818,7 +791,9 @@ func PruneLayers() error { if runtime.GOOS == "windows" { name = strings.ReplaceAll(name, "-", ":") } - deleteMap[name] = true + if strings.HasPrefix(name, "sha256:") { + deleteMap[name] = struct{}{} + } } log.Printf("total blobs: %d", len(deleteMap)) @@ -873,11 +848,11 @@ func DeleteModel(name string) error { return err } - deleteMap := make(map[string]bool) + deleteMap := make(map[string]struct{}) for _, layer := range manifest.Layers { - deleteMap[layer.Digest] = true + deleteMap[layer.Digest] = struct{}{} } - deleteMap[manifest.Config.Digest] = true + deleteMap[manifest.Config.Digest] = struct{}{} err = deleteUnusedLayers(&mp, deleteMap, false) if err != nil { @@ -1013,7 +988,7 @@ func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu var noprune string // build deleteMap to prune unused layers - deleteMap := make(map[string]bool) + deleteMap := make(map[string]struct{}) if noprune = os.Getenv("OLLAMA_NOPRUNE"); noprune == "" { manifest, _, err = GetManifest(mp) @@ -1023,9 +998,9 @@ func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu if manifest != nil { for _, l := range manifest.Layers { - deleteMap[l.Digest] = true + deleteMap[l.Digest] = struct{}{} } - deleteMap[manifest.Config.Digest] = true + deleteMap[manifest.Config.Digest] = struct{}{} } } diff --git a/server/routes.go b/server/routes.go index a543b10e..58145576 100644 --- a/server/routes.go +++ b/server/routes.go @@ -2,6 +2,7 @@ package server import ( "context" + "crypto/sha256" "encoding/json" "errors" "fmt" @@ -26,6 +27,7 @@ import ( "github.com/jmorganca/ollama/api" "github.com/jmorganca/ollama/llm" + "github.com/jmorganca/ollama/parser" "github.com/jmorganca/ollama/version" ) @@ -409,8 +411,31 @@ func CreateModelHandler(c *gin.Context) { return } - if req.Name == "" || req.Path == "" { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "name and path are required"}) + if req.Name == "" { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "name is required"}) + return + } + + if req.Path == "" && req.Modelfile == "" { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "path or modelfile are required"}) + return + } + + var modelfile io.Reader = strings.NewReader(req.Modelfile) + if req.Path != "" && req.Modelfile == "" { + bin, err := os.Open(req.Path) + if err != nil { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("error reading modelfile: %s", err)}) + return + } + defer bin.Close() + + modelfile = bin + } + + commands, err := parser.Parse(modelfile) + if err != nil { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } @@ -424,7 +449,7 @@ func CreateModelHandler(c *gin.Context) { ctx, cancel := context.WithCancel(c.Request.Context()) defer cancel() - if err := CreateModel(ctx, req.Name, req.Path, fn); err != nil { + if err := CreateModel(ctx, req.Name, commands, fn); err != nil { ch <- gin.H{"error": err.Error()} } }() @@ -625,6 +650,60 @@ func CopyModelHandler(c *gin.Context) { } } +func HeadBlobHandler(c *gin.Context) { + path, err := GetBlobsPath(c.Param("digest")) + if err != nil { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + if _, err := os.Stat(path); err != nil { + c.AbortWithStatusJSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("blob %q not found", c.Param("digest"))}) + return + } + + c.Status(http.StatusOK) +} + +func CreateBlobHandler(c *gin.Context) { + hash := sha256.New() + temp, err := os.CreateTemp("", c.Param("digest")) + if err != nil { + c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + defer temp.Close() + defer os.Remove(temp.Name()) + + if _, err := io.Copy(temp, io.TeeReader(c.Request.Body, hash)); err != nil { + c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + if fmt.Sprintf("sha256:%x", hash.Sum(nil)) != c.Param("digest") { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "digest does not match body"}) + return + } + + if err := temp.Close(); err != nil { + c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + targetPath, err := GetBlobsPath(c.Param("digest")) + if err != nil { + c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + if err := os.Rename(temp.Name(), targetPath); err != nil { + c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.Status(http.StatusCreated) +} + var defaultAllowOrigins = []string{ "localhost", "127.0.0.1", @@ -684,6 +763,8 @@ func Serve(ln net.Listener, allowOrigins []string) error { r.POST("/api/copy", CopyModelHandler) r.DELETE("/api/delete", DeleteModelHandler) r.POST("/api/show", ShowModelHandler) + r.POST("/api/blobs/:digest", CreateBlobHandler) + r.HEAD("/api/blobs/:digest", HeadBlobHandler) for _, method := range []string{http.MethodGet, http.MethodHead} { r.Handle(method, "/", func(c *gin.Context) {