diff --git a/api/client.go b/api/client.go index 29ab2698..f19f6c5e 100644 --- a/api/client.go +++ b/api/client.go @@ -116,3 +116,29 @@ func (c *Client) Pull(ctx context.Context, req *PullRequest, fn PullProgressFunc return fn(resp) }) } + +type PushProgressFunc func(PushProgress) error + +func (c *Client) Push(ctx context.Context, req *PushRequest, fn PushProgressFunc) error { + return c.stream(ctx, http.MethodPost, "/api/push", req, func(bts []byte) error { + var resp PushProgress + if err := json.Unmarshal(bts, &resp); err != nil { + return err + } + + return fn(resp) + }) +} + +type CreateProgressFunc func(CreateProgress) error + +func (c *Client) Create(ctx context.Context, req *CreateRequest, fn CreateProgressFunc) error { + return c.stream(ctx, http.MethodPost, "/api/create", req, func(bts []byte) error { + var resp CreateProgress + if err := json.Unmarshal(bts, &resp); err != nil { + return err + } + + return fn(resp) + }) +} diff --git a/api/types.go b/api/types.go index 86d116f2..af6e0e41 100644 --- a/api/types.go +++ b/api/types.go @@ -7,16 +7,6 @@ import ( "time" ) -type PullRequest struct { - Model string `json:"model"` -} - -type PullProgress struct { - Total int64 `json:"total"` - Completed int64 `json:"completed"` - Percent float64 `json:"percent"` -} - type GenerateRequest struct { Model string `json:"model"` Prompt string `json:"prompt"` @@ -25,6 +15,43 @@ type GenerateRequest struct { Options `json:"options"` } +type CreateRequest struct { + Name string `json:"name"` + Path string `json:"path"` +} + +type CreateProgress struct { + Status string `json:"status"` +} + +type PullRequest struct { + Name string `json:"name"` + Username string `json:"username"` + Password string `json:"password"` +} + +type PullProgress struct { + Status string `json:"status"` + Digest string `json:"digest,omitempty"` + Total int `json:"total,omitempty"` + Completed int `json:"completed,omitempty"` + Percent float64 `json:"percent,omitempty"` +} + +type PushRequest struct { + Name string `json:"name"` + Username string `json:"username"` + Password string `json:"password"` +} + +type PushProgress struct { + Status string `json:"status"` + Digest string `json:"digest,omitempty"` + Total int `json:"total,omitempty"` + Completed int `json:"completed,omitempty"` + Percent float64 `json:"percent,omitempty"` +} + type GenerateResponse struct { Model string `json:"model"` CreatedAt time.Time `json:"created_at"` diff --git a/cmd/cmd.go b/cmd/cmd.go index 18a90b9a..99033614 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -30,6 +30,23 @@ func cacheDir() string { return filepath.Join(home, ".ollama") } +func create(cmd *cobra.Command, args []string) error { + filename, _ := cmd.Flags().GetString("file") + client := api.NewClient() + + request := api.CreateRequest{Name: args[0], Path: filename} + fn := func(resp api.CreateProgress) error { + fmt.Println(resp.Status) + return nil + } + + if err := client.Create(context.Background(), &request, fn); err != nil { + return err + } + + return nil +} + func RunRun(cmd *cobra.Command, args []string) error { _, err := os.Stat(args[0]) switch { @@ -51,25 +68,56 @@ func RunRun(cmd *cobra.Command, args []string) error { return RunGenerate(cmd, args) } +func push(cmd *cobra.Command, args []string) error { + client := api.NewClient() + + request := api.PushRequest{Name: args[0]} + fn := func(resp api.PushProgress) error { + fmt.Println(resp.Status) + return nil + } + + if err := client.Push(context.Background(), &request, fn); err != nil { + return err + } + return nil +} + +func RunPull(cmd *cobra.Command, args []string) error { + return pull(args[0]) +} + func pull(model string) error { client := api.NewClient() + var bar *progressbar.ProgressBar - return client.Pull( - context.Background(), - &api.PullRequest{Model: model}, - func(progress api.PullProgress) error { - if bar == nil { - if progress.Percent >= 100 { - // already downloaded - return nil - } - bar = progressbar.DefaultBytes(progress.Total) + currentLayer := "" + request := api.PullRequest{Name: model} + fn := func(resp api.PullProgress) error { + if resp.Digest != currentLayer && resp.Digest != "" { + if currentLayer != "" { + fmt.Println() } + currentLayer = resp.Digest + layerStr := resp.Digest[7:23] + "..." + bar = progressbar.DefaultBytes( + int64(resp.Total), + "pulling "+layerStr, + ) + } else if resp.Digest == currentLayer && resp.Digest != "" { + bar.Set(resp.Completed) + } else { + currentLayer = "" + fmt.Println(resp.Status) + } + return nil + } - return bar.Set64(progress.Completed) - }, - ) + if err := client.Pull(context.Background(), &request, fn); err != nil { + return err + } + return nil } func RunGenerate(cmd *cobra.Command, args []string) error { @@ -215,6 +263,15 @@ func NewCLI() *cobra.Command { cobra.EnableCommandSorting = false + createCmd := &cobra.Command{ + Use: "create MODEL", + Short: "Create a model from a Modelfile", + Args: cobra.MinimumNArgs(1), + RunE: create, + } + + createCmd.Flags().StringP("file", "f", "Modelfile", "Name of the Modelfile (default \"Modelfile\")") + runCmd := &cobra.Command{ Use: "run MODEL [PROMPT]", Short: "Run a model", @@ -231,9 +288,26 @@ func NewCLI() *cobra.Command { RunE: RunServer, } + pullCmd := &cobra.Command{ + Use: "pull MODEL", + Short: "Pull a model from a registry", + Args: cobra.MinimumNArgs(1), + RunE: RunPull, + } + + pushCmd := &cobra.Command{ + Use: "push MODEL", + Short: "Push a model to a registry", + Args: cobra.MinimumNArgs(1), + RunE: push, + } + rootCmd.AddCommand( serveCmd, + createCmd, runCmd, + pullCmd, + pushCmd, ) return rootCmd diff --git a/parser/parser.go b/parser/parser.go new file mode 100644 index 00000000..48a1bc28 --- /dev/null +++ b/parser/parser.go @@ -0,0 +1,77 @@ +package parser + +import ( + "bufio" + "fmt" + "io" + "strings" +) + +type Command struct { + Name string + Arg string +} + +func Parse(reader io.Reader) ([]Command, error) { + var commands []Command + var foundModel bool + + scanner := bufio.NewScanner(reader) + multiline := false + var multilineCommand *Command + for scanner.Scan() { + line := scanner.Text() + if multiline { + // If we're in a multiline string and the line is """, end the multiline string. + if strings.TrimSpace(line) == `"""` { + multiline = false + commands = append(commands, *multilineCommand) + } else { + // Otherwise, append the line to the multiline string. + multilineCommand.Arg += "\n" + line + } + continue + } + fields := strings.Fields(line) + if len(fields) == 0 { + continue + } + + command := Command{} + switch fields[0] { + case "FROM": + command.Name = "model" + command.Arg = fields[1] + if command.Arg == "" { + return nil, fmt.Errorf("no model specified in FROM line") + } + foundModel = true + case "PROMPT": + command.Name = "prompt" + if fields[1] == `"""` { + multiline = true + multilineCommand = &command + multilineCommand.Arg = "" + } else { + command.Arg = strings.Join(fields[1:], " ") + } + case "PARAMETER": + command.Name = fields[1] + command.Arg = strings.Join(fields[2:], " ") + default: + continue + } + if !multiline { + commands = append(commands, command) + } + } + + if !foundModel { + return nil, fmt.Errorf("no FROM line for the model was specified") + } + + if multiline { + return nil, fmt.Errorf("unclosed multiline string") + } + return commands, scanner.Err() +} diff --git a/server/images.go b/server/images.go new file mode 100644 index 00000000..83f1800f --- /dev/null +++ b/server/images.go @@ -0,0 +1,842 @@ +package server + +import ( + "bytes" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "io" + "io/ioutil" + "log" + "net/http" + "os" + "path" + "path/filepath" + "strconv" + "strings" + + "github.com/jmorganca/ollama/api" + "github.com/jmorganca/ollama/parser" +) + +var DefaultRegistry string = "https://registry.ollama.ai" + +type Model struct { + Name string `json:"name"` + ModelPath string + Prompt string + Options api.Options +} + +type ManifestV2 struct { + SchemaVersion int `json:"schemaVersion"` + MediaType string `json:"mediaType"` + Config Layer `json:"config"` + Layers []*Layer `json:"layers"` +} + +type Layer struct { + MediaType string `json:"mediaType"` + Digest string `json:"digest"` + Size int `json:"size"` +} + +type LayerWithBuffer struct { + Layer + + Buffer *bytes.Buffer +} + +type ConfigV2 struct { + Architecture string `json:"architecture"` + OS string `json:"os"` + RootFS RootFS `json:"rootfs"` +} + +type RootFS struct { + Type string `json:"type"` + DiffIDs []string `json:"diff_ids"` +} + +func GetManifest(name string) (*ManifestV2, error) { + home, err := os.UserHomeDir() + if err != nil { + return nil, err + } + + fp := filepath.Join(home, ".ollama/models/manifests", name) + _, err = os.Stat(fp) + if os.IsNotExist(err) { + return nil, fmt.Errorf("couldn't find model '%s'", name) + } + + var manifest *ManifestV2 + + f, err := os.Open(fp) + if err != nil { + return nil, fmt.Errorf("couldn't open file '%s'", fp) + } + + decoder := json.NewDecoder(f) + err = decoder.Decode(&manifest) + if err != nil { + return nil, err + } + + return manifest, nil +} + +func GetModel(name string) (*Model, error) { + home, err := os.UserHomeDir() + if err != nil { + return nil, err + } + + manifest, err := GetManifest(name) + if err != nil { + return nil, err + } + + model := &Model{ + Name: name, + } + + for _, layer := range manifest.Layers { + filename := filepath.Join(home, ".ollama/models/blobs", layer.Digest) + switch layer.MediaType { + case "application/vnd.ollama.image.model": + model.ModelPath = filename + case "application/vnd.ollama.image.prompt": + data, err := os.ReadFile(filename) + if err != nil { + return nil, err + } + model.Prompt = string(data) + case "application/vnd.ollama.image.params": + /* + f, err = os.Open(filename) + if err != nil { + return nil, err + } + */ + + var opts api.Options + /* + decoder = json.NewDecoder(f) + err = decoder.Decode(&opts) + if err != nil { + return nil, err + } + */ + model.Options = opts + } + } + + return model, nil +} + +func getAbsPath(fn string) (string, error) { + if strings.HasPrefix(fn, "~/") { + home, err := os.UserHomeDir() + if err != nil { + log.Printf("error getting home directory: %v", err) + return "", err + } + fn = strings.Replace(fn, "~", home, 1) + } + + return filepath.Abs(fn) +} + +func CreateModel(name string, mf io.Reader, fn func(status string)) error { + fn("parsing modelfile") + commands, err := parser.Parse(mf) + if err != nil { + fn(fmt.Sprintf("error: %v", err)) + return err + } + + var layers []*LayerWithBuffer + param := make(map[string]string) + + for _, c := range commands { + log.Printf("[%s] - %s\n", c.Name, c.Arg) + switch c.Name { + case "model": + fn("looking for model") + mf, err := GetManifest(c.Arg) + if err != nil { + // if we couldn't read the manifest, try getting the bin file + fp, err := getAbsPath(c.Arg) + if err != nil { + fn("error determing path. exiting.") + return err + } + + fn("creating model layer") + file, err := os.Open(fp) + if err != nil { + fn(fmt.Sprintf("couldn't find model '%s'", c.Arg)) + return fmt.Errorf("failed to open file: %v", err) + } + defer file.Close() + + l, err := CreateLayer(file) + if err != nil { + fn(fmt.Sprintf("couldn't create model layer: %v", err)) + return fmt.Errorf("failed to create layer: %v", err) + } + l.MediaType = "application/vnd.ollama.image.model" + layers = append(layers, l) + } else { + log.Printf("manifest = %#v", mf) + for _, l := range mf.Layers { + newLayer, err := GetLayerWithBufferFromLayer(l) + if err != nil { + fn(fmt.Sprintf("couldn't read layer: %v", err)) + return err + } + layers = append(layers, newLayer) + } + } + case "prompt": + fn("creating prompt layer") + // remove the prompt layer if one exists + layers = removeLayerFromLayers(layers, "application/vnd.ollama.image.prompt") + + prompt := strings.NewReader(c.Arg) + l, err := CreateLayer(prompt) + if err != nil { + fn(fmt.Sprintf("couldn't create prompt layer: %v", err)) + return fmt.Errorf("failed to create layer: %v", err) + } + l.MediaType = "application/vnd.ollama.image.prompt" + layers = append(layers, l) + default: + param[c.Name] = c.Arg + } + } + + // Create a single layer for the parameters + fn("creating parameter layer") + if len(param) > 0 { + layers = removeLayerFromLayers(layers, "application/vnd.ollama.image.params") + paramData, err := paramsToReader(param) + if err != nil { + return fmt.Errorf("couldn't create params json: %v", err) + } + l, err := CreateLayer(paramData) + if err != nil { + return fmt.Errorf("failed to create layer: %v", err) + } + l.MediaType = "application/vnd.ollama.image.params" + layers = append(layers, l) + } + + digests, err := getLayerDigests(layers) + if err != nil { + return err + } + + var manifestLayers []*Layer + for _, l := range layers { + manifestLayers = append(manifestLayers, &l.Layer) + } + + // Create a layer for the config object + fn("creating config layer") + cfg, err := createConfigLayer(digests) + if err != nil { + return err + } + layers = append(layers, cfg) + + err = SaveLayers(layers, fn, false) + if err != nil { + fn(fmt.Sprintf("error saving layers: %v", err)) + return err + } + + // Create the manifest + fn("writing manifest") + err = CreateManifest(name, cfg, manifestLayers) + if err != nil { + fn(fmt.Sprintf("error creating manifest: %v", err)) + return err + } + + fn("success") + return nil +} + +func removeLayerFromLayers(layers []*LayerWithBuffer, mediaType string) []*LayerWithBuffer { + j := 0 + for _, l := range layers { + if l.MediaType != mediaType { + layers[j] = l + j++ + } + } + return layers[:j] +} + +func SaveLayers(layers []*LayerWithBuffer, fn func(status string), force bool) error { + home, err := os.UserHomeDir() + if err != nil { + log.Printf("error getting home directory: %v", err) + return err + } + + dir := filepath.Join(home, ".ollama/models/blobs") + + err = os.MkdirAll(dir, 0o700) + if err != nil { + return fmt.Errorf("make blobs directory: %w", err) + } + + // Write each of the layers to disk + for _, layer := range layers { + fp := filepath.Join(dir, layer.Digest) + + _, err = os.Stat(fp) + if os.IsNotExist(err) || force { + fn(fmt.Sprintf("writing layer %s", layer.Digest)) + out, err := os.Create(fp) + if err != nil { + log.Printf("couldn't create %s", fp) + return err + } + defer out.Close() + + _, err = io.Copy(out, layer.Buffer) + if err != nil { + return err + } + } else { + fn(fmt.Sprintf("using already created layer %s", layer.Digest)) + } + } + + return nil +} + +func CreateManifest(name string, cfg *LayerWithBuffer, layers []*Layer) error { + home, err := os.UserHomeDir() + if err != nil { + log.Printf("error getting home directory: %v", err) + return err + } + + manifest := ManifestV2{ + SchemaVersion: 2, + MediaType: "application/vnd.docker.distribution.manifest.v2+json", + Config: Layer{ + MediaType: cfg.MediaType, + Size: cfg.Size, + Digest: cfg.Digest, + }, + Layers: layers, + } + + manifestJSON, err := json.Marshal(manifest) + if err != nil { + return err + } + + fp := filepath.Join(home, ".ollama/models/manifests", name) + err = os.WriteFile(fp, manifestJSON, 0644) + if err != nil { + log.Printf("couldn't write to %s", fp) + return err + } + return nil +} + +func GetLayerWithBufferFromLayer(layer *Layer) (*LayerWithBuffer, error) { + home, err := os.UserHomeDir() + if err != nil { + return nil, err + } + + fp := filepath.Join(home, ".ollama/models/blobs", layer.Digest) + file, err := os.Open(fp) + if err != nil { + return nil, fmt.Errorf("could not open blob: %w", err) + } + defer file.Close() + + newLayer, err := CreateLayer(file) + if err != nil { + return nil, err + } + newLayer.MediaType = layer.MediaType + return newLayer, nil +} + +func paramsToReader(m map[string]string) (io.Reader, error) { + data, err := json.MarshalIndent(m, "", " ") + if err != nil { + return nil, err + } + + return strings.NewReader(string(data)), nil +} + +func getLayerDigests(layers []*LayerWithBuffer) ([]string, error) { + var digests []string + for _, l := range layers { + if l.Digest == "" { + return nil, fmt.Errorf("layer is missing a digest") + } + digests = append(digests, l.Digest) + } + return digests, nil +} + +// CreateLayer creates a Layer object from a given file +func CreateLayer(f io.Reader) (*LayerWithBuffer, error) { + buf := new(bytes.Buffer) + _, err := io.Copy(buf, f) + if err != nil { + return nil, err + } + + digest, size := GetSHA256Digest(buf) + + layer := &LayerWithBuffer{ + Layer: Layer{ + MediaType: "application/vnd.docker.image.rootfs.diff.tar", + Digest: digest, + Size: size, + }, + Buffer: buf, + } + + return layer, nil +} + +func PushModel(name, username, password string, fn func(status, digest string, Total, Completed int, Percent float64)) error { + fn("retrieving manifest", "", 0, 0, 0) + manifest, err := GetManifest(name) + if err != nil { + fn("couldn't retrieve manifest", "", 0, 0, 0) + return err + } + + var repoName string + var tag string + + comps := strings.Split(name, ":") + switch { + case len(comps) < 1 || len(comps) > 2: + return fmt.Errorf("repository name was invalid") + case len(comps) == 1: + repoName = comps[0] + tag = "latest" + case len(comps) == 2: + repoName = comps[0] + tag = comps[1] + } + + var layers []*Layer + var total int + var completed int + for _, layer := range manifest.Layers { + layers = append(layers, layer) + total += layer.Size + } + layers = append(layers, &manifest.Config) + total += manifest.Config.Size + + for _, layer := range layers { + exists, err := checkBlobExistence(DefaultRegistry, repoName, layer.Digest, username, password) + if err != nil { + return err + } + + if exists { + completed += layer.Size + fn("using existing layer", layer.Digest, total, completed, float64(completed)/float64(total)) + continue + } + + fn("starting upload", layer.Digest, total, completed, float64(completed)/float64(total)) + + location, err := startUpload(DefaultRegistry, repoName, username, password) + if err != nil { + log.Printf("couldn't start upload: %v", err) + return err + } + + err = uploadBlob(location, layer, username, password) + if err != nil { + log.Printf("error uploading blob: %v", err) + return err + } + completed += layer.Size + fn("upload complete", layer.Digest, total, completed, float64(completed)/float64(total)) + } + + fn("pushing manifest", "", total, completed, float64(completed/total)) + url := fmt.Sprintf("%s/v2/%s/manifests/%s", DefaultRegistry, repoName, tag) + headers := map[string]string{ + "Content-Type": "application/vnd.docker.distribution.manifest.v2+json", + } + + manifestJSON, err := json.Marshal(manifest) + if err != nil { + return err + } + + resp, err := makeRequest("PUT", url, headers, bytes.NewReader(manifestJSON), username, password) + if err != nil { + return err + } + defer resp.Body.Close() + + // Check for success: For a successful upload, the Docker registry will respond with a 201 Created + if resp.StatusCode != http.StatusCreated { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("registry responded with code %d: %v", resp.StatusCode, string(body)) + } + + fn("success", "", total, completed, 1.0) + + return nil +} + +func PullModel(name, username, password string, fn func(status, digest string, Total, Completed int, Percent float64)) error { + var repoName string + var tag string + + comps := strings.Split(name, ":") + switch { + case len(comps) < 1 || len(comps) > 2: + return fmt.Errorf("repository name was invalid") + case len(comps) == 1: + repoName = comps[0] + tag = "latest" + case len(comps) == 2: + repoName = comps[0] + tag = comps[1] + } + + fn("pulling manifest", "", 0, 0, 0) + + manifest, err := pullModelManifest(DefaultRegistry, repoName, tag, username, password) + if err != nil { + return fmt.Errorf("pull model manifest: %q", err) + } + + log.Printf("manifest = %#v", manifest) + + var layers []*Layer + var total int + var completed int + for _, layer := range manifest.Layers { + layers = append(layers, layer) + total += layer.Size + } + layers = append(layers, &manifest.Config) + total += manifest.Config.Size + + for _, layer := range layers { + fn("starting download", layer.Digest, total, completed, float64(completed)/float64(total)) + if err := downloadBlob(DefaultRegistry, repoName, layer.Digest, username, password, fn); err != nil { + fn(fmt.Sprintf("error downloading: %v", err), layer.Digest, 0, 0, 0) + return err + } + completed += layer.Size + fn("download complete", layer.Digest, total, completed, float64(completed)/float64(total)) + } + + fn("writing manifest", "", total, completed, 1.0) + + home, err := os.UserHomeDir() + if err != nil { + return err + } + + manifestJSON, err := json.Marshal(manifest) + if err != nil { + return err + } + + fp := filepath.Join(home, ".ollama/models/manifests", name) + + err = os.MkdirAll(path.Dir(fp), 0o700) + if err != nil { + return fmt.Errorf("make manifests directory: %w", err) + } + + err = os.WriteFile(fp, manifestJSON, 0644) + if err != nil { + log.Printf("couldn't write to %s", fp) + return err + } + + fn("success", "", total, completed, 1.0) + + return nil +} + +func pullModelManifest(registryURL, repoName, tag, username, password string) (*ManifestV2, error) { + url := fmt.Sprintf("%s/v2/%s/manifests/%s", registryURL, repoName, tag) + headers := map[string]string{ + "Accept": "application/vnd.docker.distribution.manifest.v2+json", + } + + resp, err := makeRequest("GET", url, headers, nil, username, password) + if err != nil { + log.Printf("couldn't get manifest: %v", err) + return nil, err + } + defer resp.Body.Close() + + // Check for success: For a successful upload, the Docker registry will respond with a 201 Created + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("registry responded with code %d: %v", resp.StatusCode, string(body)) + } + + var m *ManifestV2 + if err := json.NewDecoder(resp.Body).Decode(&m); err != nil { + return nil, err + } + + return m, err +} + +func createConfigLayer(layers []string) (*LayerWithBuffer, error) { + // TODO change architecture and OS + config := ConfigV2{ + Architecture: "arm64", + OS: "linux", + RootFS: RootFS{ + Type: "layers", + DiffIDs: layers, + }, + } + + configJSON, err := json.Marshal(config) + if err != nil { + return nil, err + } + + buf := bytes.NewBuffer(configJSON) + digest, size := GetSHA256Digest(buf) + + layer := &LayerWithBuffer{ + Layer: Layer{ + MediaType: "application/vnd.docker.container.image.v1+json", + Digest: digest, + Size: size, + }, + Buffer: buf, + } + return layer, nil +} + +// GetSHA256Digest returns the SHA256 hash of a given buffer and returns it, and the size of buffer +func GetSHA256Digest(data *bytes.Buffer) (string, int) { + layerBytes := data.Bytes() + hash := sha256.Sum256(layerBytes) + return "sha256:" + hex.EncodeToString(hash[:]), len(layerBytes) +} + +func startUpload(registryURL string, repositoryName string, username string, password string) (string, error) { + url := fmt.Sprintf("%s/v2/%s/blobs/uploads/", registryURL, repositoryName) + + resp, err := makeRequest("POST", url, nil, nil, username, password) + if err != nil { + log.Printf("couldn't start upload: %v", err) + return "", err + } + defer resp.Body.Close() + + // Check for success + if resp.StatusCode != http.StatusAccepted { + body, _ := io.ReadAll(resp.Body) + return "", fmt.Errorf("registry responded with code %d: %v", resp.StatusCode, string(body)) + } + + // Extract UUID location from header + location := resp.Header.Get("Location") + if location == "" { + return "", fmt.Errorf("location header is missing in response") + } + + return location, nil +} + +// Function to check if a blob already exists in the Docker registry +func checkBlobExistence(registryURL string, repositoryName string, digest string, username string, password string) (bool, error) { + url := fmt.Sprintf("%s/v2/%s/blobs/%s", registryURL, repositoryName, digest) + + resp, err := makeRequest("HEAD", url, nil, nil, username, password) + if err != nil { + log.Printf("couldn't check for blob: %v", err) + return false, err + } + defer resp.Body.Close() + + // Check for success: If the blob exists, the Docker registry will respond with a 200 OK + return resp.StatusCode == http.StatusOK, nil +} + +func uploadBlob(location string, layer *Layer, username string, password string) error { + home, err := os.UserHomeDir() + if err != nil { + return err + } + + // Create URL + url := fmt.Sprintf("%s&digest=%s", location, layer.Digest) + + headers := make(map[string]string) + headers["Content-Length"] = fmt.Sprintf("%d", layer.Size) + headers["Content-Type"] = "application/octet-stream" + + // TODO change from monolithic uploads to chunked uploads + // TODO allow resumability + // TODO allow canceling uploads via DELETE + // TODO allow cross repo blob mount + + fp := filepath.Join(home, ".ollama/models/blobs", layer.Digest) + f, err := os.Open(fp) + if err != nil { + return err + } + + resp, err := makeRequest("PUT", url, headers, f, username, password) + if err != nil { + log.Printf("couldn't upload blob: %v", err) + return err + } + defer resp.Body.Close() + + // Check for success: For a successful upload, the Docker registry will respond with a 201 Created + if resp.StatusCode != http.StatusCreated { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("registry responded with code %d: %v", resp.StatusCode, string(body)) + } + + return nil +} + +func downloadBlob(registryURL, repoName, digest string, username, password string, fn func(status, digest string, Total, Completed int, Percent float64)) error { + home, err := os.UserHomeDir() + if err != nil { + return err + } + + fp := filepath.Join(home, ".ollama/models/blobs", digest) + + _, err = os.Stat(fp) + if !os.IsNotExist(err) { + // we already have the file, so return + log.Printf("already have %s\n", digest) + return nil + } + + var size int64 + + fi, err := os.Stat(fp + "-partial") + switch { + case errors.Is(err, os.ErrNotExist): + // noop, file doesn't exist so create it + case err != nil: + return fmt.Errorf("stat: %w", err) + default: + size = fi.Size() + } + + url := fmt.Sprintf("%s/v2/%s/blobs/%s", registryURL, repoName, digest) + headers := map[string]string{ + "Range": fmt.Sprintf("bytes=%d-", size), + } + + resp, err := makeRequest("GET", url, headers, nil, username, password) + if err != nil { + log.Printf("couldn't download blob: %v", err) + return err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusPartialContent { + body, _ := ioutil.ReadAll(resp.Body) + return fmt.Errorf("registry responded with code %d: %v", resp.StatusCode, string(body)) + } + + err = os.MkdirAll(path.Dir(fp), 0o700) + if err != nil { + return fmt.Errorf("make blobs directory: %w", err) + } + + out, err := os.OpenFile(fp+"-partial", os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644) + if err != nil { + panic(err) + } + defer out.Close() + + remaining, _ := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64) + completed := size + total := remaining + completed + + for { + fn(fmt.Sprintf("Downloading %s", digest), digest, int(total), int(completed), float64(completed)/float64(total)) + if completed >= total { + fmt.Printf("finished downloading\n") + err = os.Rename(fp+"-partial", fp) + if err != nil { + fmt.Printf("error: %v\n", err) + fn(fmt.Sprintf("error renaming file: %v", err), digest, int(total), int(completed), 1) + return err + } + break + } + + n, err := io.CopyN(out, resp.Body, 8192) + if err != nil && !errors.Is(err, io.EOF) { + return err + } + completed += n + } + + log.Printf("success getting %s\n", digest) + return nil +} + +func makeRequest(method, url string, headers map[string]string, body io.Reader, username, password string) (*http.Response, error) { + req, err := http.NewRequest(method, url, body) + if err != nil { + return nil, err + } + + for k, v := range headers { + req.Header.Set(k, v) + } + + // TODO: better auth + if username != "" && password != "" { + req.SetBasicAuth(username, password) + } + + client := &http.Client{ + CheckRedirect: func(req *http.Request, via []*http.Request) error { + if len(via) >= 10 { + return fmt.Errorf("too many redirects") + } + log.Printf("redirected to: %s\n", req.URL) + return nil + }, + } + resp, err := client.Do(req) + if err != nil { + return nil, err + } + + return resp, nil +} diff --git a/server/models.go b/server/models.go deleted file mode 100644 index de46e96f..00000000 --- a/server/models.go +++ /dev/null @@ -1,128 +0,0 @@ -package server - -import ( - "encoding/json" - "errors" - "fmt" - "io" - "net/http" - "os" - "path/filepath" - "strconv" -) - -const directoryURL = "https://ollama.ai/api/models" - -type Model struct { - Name string `json:"name"` - DisplayName string `json:"display_name"` - Parameters string `json:"parameters"` - URL string `json:"url"` - ShortDescription string `json:"short_description"` - Description string `json:"description"` - PublishedBy string `json:"published_by"` - OriginalAuthor string `json:"original_author"` - OriginalURL string `json:"original_url"` - License string `json:"license"` -} - -func (m *Model) FullName() string { - home, err := os.UserHomeDir() - if err != nil { - panic(err) - } - - return filepath.Join(home, ".ollama", "models", m.Name+".bin") -} - -func (m *Model) TempFile() string { - fullName := m.FullName() - return filepath.Join( - filepath.Dir(fullName), - fmt.Sprintf(".%s.part", filepath.Base(fullName)), - ) -} - -func getRemote(model string) (*Model, error) { - // resolve the model download from our directory - resp, err := http.Get(directoryURL) - if err != nil { - return nil, fmt.Errorf("failed to get directory: %w", err) - } - defer resp.Body.Close() - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read directory: %w", err) - } - var models []Model - err = json.Unmarshal(body, &models) - if err != nil { - return nil, fmt.Errorf("failed to parse directory: %w", err) - } - for _, m := range models { - if m.Name == model { - return &m, nil - } - } - return nil, fmt.Errorf("model not found in directory: %s", model) -} - -func saveModel(model *Model, fn func(total, completed int64)) error { - // this models cache directory is created by the server on startup - - client := &http.Client{} - req, err := http.NewRequest("GET", model.URL, nil) - if err != nil { - return fmt.Errorf("failed to download model: %w", err) - } - - var size int64 - - // completed file doesn't exist, check partial file - fi, err := os.Stat(model.TempFile()) - switch { - case errors.Is(err, os.ErrNotExist): - // noop, file doesn't exist so create it - case err != nil: - return fmt.Errorf("stat: %w", err) - default: - size = fi.Size() - } - - req.Header.Add("Range", fmt.Sprintf("bytes=%d-", size)) - - resp, err := client.Do(req) - if err != nil { - return fmt.Errorf("failed to download model: %w", err) - } - defer resp.Body.Close() - - if resp.StatusCode >= 400 { - return fmt.Errorf("failed to download model: %s", resp.Status) - } - - out, err := os.OpenFile(model.TempFile(), os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644) - if err != nil { - panic(err) - } - defer out.Close() - - remaining, _ := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64) - completed := size - - total := remaining + completed - - for { - fn(total, completed) - if completed >= total { - return os.Rename(model.TempFile(), model.FullName()) - } - - n, err := io.CopyN(out, resp.Body, 8192) - if err != nil && !errors.Is(err, io.EOF) { - return err - } - - completed += n - } -} diff --git a/server/routes.go b/server/routes.go index 36c182e3..fea96df8 100644 --- a/server/routes.go +++ b/server/routes.go @@ -1,12 +1,10 @@ package server import ( - "embed" "encoding/json" - "errors" + "fmt" "io" "log" - "math" "net" "net/http" "os" @@ -16,16 +14,11 @@ import ( "time" "github.com/gin-gonic/gin" - "github.com/lithammer/fuzzysearch/fuzzy" "github.com/jmorganca/ollama/api" "github.com/jmorganca/ollama/llama" ) -//go:embed templates/* -var templatesFS embed.FS -var templates = template.Must(template.ParseFS(templatesFS, "templates/*.prompt")) - func cacheDir() string { home, err := os.UserHomeDir() if err != nil { @@ -40,6 +33,7 @@ func generate(c *gin.Context) { req := api.GenerateRequest{ Options: api.DefaultOptions(), + Prompt: "", } if err := c.ShouldBindJSON(&req); err != nil { @@ -47,34 +41,28 @@ func generate(c *gin.Context) { return } - if remoteModel, _ := getRemote(req.Model); remoteModel != nil { - req.Model = remoteModel.FullName() - } - if _, err := os.Stat(req.Model); err != nil { - if !errors.Is(err, os.ErrNotExist) { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - req.Model = filepath.Join(cacheDir(), "models", req.Model+".bin") + model, err := GetModel(req.Model) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return } - templateNames := make([]string, 0, len(templates.Templates())) - for _, template := range templates.Templates() { - templateNames = append(templateNames, template.Name()) + templ, err := template.New("").Parse(model.Prompt) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return } - match, _ := matchRankOne(filepath.Base(req.Model), templateNames) - if template := templates.Lookup(match); template != nil { - var sb strings.Builder - if err := template.Execute(&sb, req); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - req.Prompt = sb.String() + var sb strings.Builder + if err = templ.Execute(&sb, req); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return } + req.Prompt = sb.String() - llm, err := llama.New(req.Model, req.Options) + fmt.Printf("prompt = >>>%s<<<\n", req.Prompt) + + llm, err := llama.New(model.ModelPath, req.Options) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return @@ -105,40 +93,84 @@ func pull(c *gin.Context) { return } - remote, err := getRemote(req.Model) - if err != nil { - c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()}) - return - } + ch := make(chan any) + go func() { + defer close(ch) + fn := func(status, digest string, total, completed int, percent float64) { + ch <- api.PullProgress{ + Status: status, + Digest: digest, + Total: total, + Completed: completed, + Percent: percent, + } + } + if err := PullModel(req.Name, req.Username, req.Password, fn); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + }() - // check if completed file exists - fi, err := os.Stat(remote.FullName()) - switch { - case errors.Is(err, os.ErrNotExist): - // noop, file doesn't exist so create it - case err != nil: - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - default: - c.JSON(http.StatusOK, api.PullProgress{ - Total: fi.Size(), - Completed: fi.Size(), - Percent: 100, - }) + streamResponse(c, ch) +} +func push(c *gin.Context) { + var req api.PushRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } ch := make(chan any) go func() { defer close(ch) - saveModel(remote, func(total, completed int64) { - ch <- api.PullProgress{ + fn := func(status, digest string, total, completed int, percent float64) { + ch <- api.PushProgress{ + Status: status, + Digest: digest, Total: total, Completed: completed, - Percent: float64(completed) / float64(total) * 100, + Percent: percent, } - }) + } + if err := PushModel(req.Name, req.Username, req.Password, fn); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + }() + + streamResponse(c, ch) +} + +func create(c *gin.Context) { + var req api.CreateRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()}) + return + } + + // NOTE consider passing the entire Modelfile in the json instead of the path to it + + file, err := os.Open(req.Path) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()}) + return + } + defer file.Close() + + ch := make(chan any) + go func() { + defer close(ch) + fn := func(status string) { + ch <- api.CreateProgress{ + Status: status, + } + } + + if err := CreateModel(req.Name, file, fn); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()}) + return + } }() streamResponse(c, ch) @@ -153,6 +185,8 @@ func Serve(ln net.Listener) error { r.POST("/api/pull", pull) r.POST("/api/generate", generate) + r.POST("/api/create", create) + r.POST("/api/push", push) log.Printf("Listening on %s", ln.Addr()) s := &http.Server{ @@ -162,18 +196,6 @@ func Serve(ln net.Listener) error { return s.Serve(ln) } -func matchRankOne(source string, targets []string) (bestMatch string, bestRank int) { - bestRank = math.MaxInt - for _, target := range targets { - if rank := fuzzy.LevenshteinDistance(source, target); bestRank > rank { - bestRank = rank - bestMatch = target - } - } - - return -} - func streamResponse(c *gin.Context, ch chan any) { c.Stream(func(w io.Writer) bool { val, ok := <-ch