Move hub auth out to new package

This commit is contained in:
Daniel Hiltgen 2024-02-05 12:59:52 -08:00 committed by jmorganca
parent 9da9e8fb72
commit f397e0e988
6 changed files with 142 additions and 115 deletions

View file

@ -1,4 +1,4 @@
package server package auth
import ( import (
"bytes" "bytes"
@ -24,6 +24,10 @@ import (
"github.com/jmorganca/ollama/api" "github.com/jmorganca/ollama/api"
) )
const (
KeyType = "id_ed25519"
)
type AuthRedirect struct { type AuthRedirect struct {
Realm string Realm string
Service string Service string
@ -71,39 +75,47 @@ func (r AuthRedirect) URL() (*url.URL, error) {
return redirectURL, nil return redirectURL, nil
} }
func getAuthToken(ctx context.Context, redirData AuthRedirect) (string, error) { func SignRequest(method, url string, data []byte, headers http.Header) error {
home, err := os.UserHomeDir()
if err != nil {
return err
}
keyPath := filepath.Join(home, ".ollama", KeyType)
rawKey, err := os.ReadFile(keyPath)
if err != nil {
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
return err
}
s := SignatureData{
Method: method,
Path: url,
Data: data,
}
sig, err := s.Sign(rawKey)
if err != nil {
return err
}
headers.Set("Authorization", sig)
return nil
}
func GetAuthToken(ctx context.Context, redirData AuthRedirect) (string, error) {
redirectURL, err := redirData.URL() redirectURL, err := redirData.URL()
if err != nil { if err != nil {
return "", err return "", err
} }
home, err := os.UserHomeDir()
if err != nil {
return "", err
}
keyPath := filepath.Join(home, ".ollama", "id_ed25519")
rawKey, err := os.ReadFile(keyPath)
if err != nil {
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
return "", err
}
s := SignatureData{
Method: http.MethodGet,
Path: redirectURL.String(),
Data: nil,
}
sig, err := s.Sign(rawKey)
if err != nil {
return "", err
}
headers := make(http.Header) headers := make(http.Header)
headers.Set("Authorization", sig) err = SignRequest(http.MethodGet, redirectURL.String(), nil, headers)
resp, err := makeRequest(ctx, http.MethodGet, redirectURL, headers, nil, nil) if err != nil {
return "", err
}
resp, err := MakeRequest(ctx, http.MethodGet, redirectURL, headers, nil, nil)
if err != nil { if err != nil {
slog.Info(fmt.Sprintf("couldn't get token: %q", err)) slog.Info(fmt.Sprintf("couldn't get token: %q", err))
return "", err return "", err

72
auth/request.go Normal file
View file

@ -0,0 +1,72 @@
package auth
import (
"context"
"fmt"
"io"
"net/http"
"net/url"
"runtime"
"strconv"
"github.com/jmorganca/ollama/version"
)
type RegistryOptions struct {
Insecure bool
Username string
Password string
Token string
}
func MakeRequest(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.Reader, regOpts *RegistryOptions) (*http.Response, error) {
if requestURL.Scheme != "http" && regOpts != nil && regOpts.Insecure {
requestURL.Scheme = "http"
}
req, err := http.NewRequestWithContext(ctx, method, requestURL.String(), body)
if err != nil {
return nil, err
}
if headers != nil {
req.Header = headers
}
if regOpts != nil {
if regOpts.Token != "" {
req.Header.Set("Authorization", "Bearer "+regOpts.Token)
} else if regOpts.Username != "" && regOpts.Password != "" {
req.SetBasicAuth(regOpts.Username, regOpts.Password)
}
}
req.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
if s := req.Header.Get("Content-Length"); s != "" {
contentLength, err := strconv.ParseInt(s, 10, 64)
if err != nil {
return nil, err
}
req.ContentLength = contentLength
}
proxyURL, err := http.ProxyFromEnvironment(req)
if err != nil {
return nil, err
}
client := http.Client{
Transport: &http.Transport{
Proxy: http.ProxyURL(proxyURL),
},
}
resp, err := client.Do(req)
if err != nil {
return nil, err
}
return resp, nil
}

View file

@ -22,6 +22,7 @@ import (
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"github.com/jmorganca/ollama/api" "github.com/jmorganca/ollama/api"
"github.com/jmorganca/ollama/auth"
"github.com/jmorganca/ollama/format" "github.com/jmorganca/ollama/format"
) )
@ -85,7 +86,7 @@ func (p *blobDownloadPart) Write(b []byte) (n int, err error) {
return n, nil return n, nil
} }
func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error { func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *auth.RegistryOptions) error {
partFilePaths, err := filepath.Glob(b.Name + "-partial-*") partFilePaths, err := filepath.Glob(b.Name + "-partial-*")
if err != nil { if err != nil {
return err return err
@ -137,11 +138,11 @@ func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *R
return nil return nil
} }
func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) { func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *auth.RegistryOptions) {
b.err = b.run(ctx, requestURL, opts) b.err = b.run(ctx, requestURL, opts)
} }
func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error { func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *auth.RegistryOptions) error {
defer blobDownloadManager.Delete(b.Digest) defer blobDownloadManager.Delete(b.Digest)
ctx, b.CancelFunc = context.WithCancel(ctx) ctx, b.CancelFunc = context.WithCancel(ctx)
@ -210,7 +211,7 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *Regis
return nil return nil
} }
func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart, opts *RegistryOptions) error { func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart, opts *auth.RegistryOptions) error {
g, ctx := errgroup.WithContext(ctx) g, ctx := errgroup.WithContext(ctx)
g.Go(func() error { g.Go(func() error {
headers := make(http.Header) headers := make(http.Header)
@ -334,7 +335,7 @@ func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse))
type downloadOpts struct { type downloadOpts struct {
mp ModelPath mp ModelPath
digest string digest string
regOpts *RegistryOptions regOpts *auth.RegistryOptions
fn func(api.ProgressResponse) fn func(api.ProgressResponse)
} }

View file

@ -16,25 +16,17 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"runtime" "runtime"
"strconv"
"strings" "strings"
"text/template" "text/template"
"golang.org/x/exp/slices" "golang.org/x/exp/slices"
"github.com/jmorganca/ollama/api" "github.com/jmorganca/ollama/api"
"github.com/jmorganca/ollama/auth"
"github.com/jmorganca/ollama/llm" "github.com/jmorganca/ollama/llm"
"github.com/jmorganca/ollama/parser" "github.com/jmorganca/ollama/parser"
"github.com/jmorganca/ollama/version"
) )
type RegistryOptions struct {
Insecure bool
Username string
Password string
Token string
}
type Model struct { type Model struct {
Name string `json:"name"` Name string `json:"name"`
Config ConfigV2 Config ConfigV2
@ -320,7 +312,7 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
switch { switch {
case errors.Is(err, os.ErrNotExist): case errors.Is(err, os.ErrNotExist):
fn(api.ProgressResponse{Status: "pulling model"}) fn(api.ProgressResponse{Status: "pulling model"})
if err := PullModel(ctx, c.Args, &RegistryOptions{}, fn); err != nil { if err := PullModel(ctx, c.Args, &auth.RegistryOptions{}, fn); err != nil {
return err return err
} }
@ -840,7 +832,7 @@ PARAMETER {{ $k }} {{ printf "%#v" $parameter }}
return buf.String(), nil return buf.String(), nil
} }
func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error { func PushModel(ctx context.Context, name string, regOpts *auth.RegistryOptions, fn func(api.ProgressResponse)) error {
mp := ParseModelPath(name) mp := ParseModelPath(name)
fn(api.ProgressResponse{Status: "retrieving manifest"}) fn(api.ProgressResponse{Status: "retrieving manifest"})
@ -890,7 +882,7 @@ func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
return nil return nil
} }
func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error { func PullModel(ctx context.Context, name string, regOpts *auth.RegistryOptions, fn func(api.ProgressResponse)) error {
mp := ParseModelPath(name) mp := ParseModelPath(name)
var manifest *ManifestV2 var manifest *ManifestV2
@ -996,7 +988,7 @@ func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
return nil return nil
} }
func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *RegistryOptions) (*ManifestV2, error) { func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *auth.RegistryOptions) (*ManifestV2, error) {
requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag) requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
headers := make(http.Header) headers := make(http.Header)
@ -1028,9 +1020,9 @@ func GetSHA256Digest(r io.Reader) (string, int64) {
var errUnauthorized = fmt.Errorf("unauthorized") var errUnauthorized = fmt.Errorf("unauthorized")
func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *RegistryOptions) (*http.Response, error) { func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *auth.RegistryOptions) (*http.Response, error) {
for i := 0; i < 2; i++ { for i := 0; i < 2; i++ {
resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts) resp, err := auth.MakeRequest(ctx, method, requestURL, headers, body, regOpts)
if err != nil { if err != nil {
if !errors.Is(err, context.Canceled) { if !errors.Is(err, context.Canceled) {
slog.Info(fmt.Sprintf("request failed: %v", err)) slog.Info(fmt.Sprintf("request failed: %v", err))
@ -1042,9 +1034,9 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
switch { switch {
case resp.StatusCode == http.StatusUnauthorized: case resp.StatusCode == http.StatusUnauthorized:
// Handle authentication error with one retry // Handle authentication error with one retry
auth := resp.Header.Get("www-authenticate") authenticate := resp.Header.Get("www-authenticate")
authRedir := ParseAuthRedirectString(auth) authRedir := ParseAuthRedirectString(authenticate)
token, err := getAuthToken(ctx, authRedir) token, err := auth.GetAuthToken(ctx, authRedir)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -1071,58 +1063,6 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
return nil, errUnauthorized return nil, errUnauthorized
} }
func makeRequest(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.Reader, regOpts *RegistryOptions) (*http.Response, error) {
if requestURL.Scheme != "http" && regOpts != nil && regOpts.Insecure {
requestURL.Scheme = "http"
}
req, err := http.NewRequestWithContext(ctx, method, requestURL.String(), body)
if err != nil {
return nil, err
}
if headers != nil {
req.Header = headers
}
if regOpts != nil {
if regOpts.Token != "" {
req.Header.Set("Authorization", "Bearer "+regOpts.Token)
} else if regOpts.Username != "" && regOpts.Password != "" {
req.SetBasicAuth(regOpts.Username, regOpts.Password)
}
}
req.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
if s := req.Header.Get("Content-Length"); s != "" {
contentLength, err := strconv.ParseInt(s, 10, 64)
if err != nil {
return nil, err
}
req.ContentLength = contentLength
}
proxyURL, err := http.ProxyFromEnvironment(req)
if err != nil {
return nil, err
}
client := http.Client{
Transport: &http.Transport{
Proxy: http.ProxyURL(proxyURL),
},
}
resp, err := client.Do(req)
if err != nil {
return nil, err
}
return resp, nil
}
func getValue(header, key string) string { func getValue(header, key string) string {
startIdx := strings.Index(header, key+"=") startIdx := strings.Index(header, key+"=")
if startIdx == -1 { if startIdx == -1 {
@ -1146,10 +1086,10 @@ func getValue(header, key string) string {
return header[startIdx:endIdx] return header[startIdx:endIdx]
} }
func ParseAuthRedirectString(authStr string) AuthRedirect { func ParseAuthRedirectString(authStr string) auth.AuthRedirect {
authStr = strings.TrimPrefix(authStr, "Bearer ") authStr = strings.TrimPrefix(authStr, "Bearer ")
return AuthRedirect{ return auth.AuthRedirect{
Realm: getValue(authStr, "realm"), Realm: getValue(authStr, "realm"),
Service: getValue(authStr, "service"), Service: getValue(authStr, "service"),
Scope: getValue(authStr, "scope"), Scope: getValue(authStr, "scope"),

View file

@ -25,6 +25,7 @@ import (
"golang.org/x/exp/slices" "golang.org/x/exp/slices"
"github.com/jmorganca/ollama/api" "github.com/jmorganca/ollama/api"
"github.com/jmorganca/ollama/auth"
"github.com/jmorganca/ollama/gpu" "github.com/jmorganca/ollama/gpu"
"github.com/jmorganca/ollama/llm" "github.com/jmorganca/ollama/llm"
"github.com/jmorganca/ollama/openai" "github.com/jmorganca/ollama/openai"
@ -479,7 +480,7 @@ func PullModelHandler(c *gin.Context) {
ch <- r ch <- r
} }
regOpts := &RegistryOptions{ regOpts := &auth.RegistryOptions{
Insecure: req.Insecure, Insecure: req.Insecure,
} }
@ -528,7 +529,7 @@ func PushModelHandler(c *gin.Context) {
ch <- r ch <- r
} }
regOpts := &RegistryOptions{ regOpts := &auth.RegistryOptions{
Insecure: req.Insecure, Insecure: req.Insecure,
} }

View file

@ -18,6 +18,7 @@ import (
"time" "time"
"github.com/jmorganca/ollama/api" "github.com/jmorganca/ollama/api"
"github.com/jmorganca/ollama/auth"
"github.com/jmorganca/ollama/format" "github.com/jmorganca/ollama/format"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
) )
@ -49,7 +50,7 @@ const (
maxUploadPartSize int64 = 1000 * format.MegaByte maxUploadPartSize int64 = 1000 * format.MegaByte
) )
func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error { func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *auth.RegistryOptions) error {
p, err := GetBlobsPath(b.Digest) p, err := GetBlobsPath(b.Digest)
if err != nil { if err != nil {
return err return err
@ -121,7 +122,7 @@ func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *Reg
// Run uploads blob parts to the upstream. If the upstream supports redirection, parts will be uploaded // Run uploads blob parts to the upstream. If the upstream supports redirection, parts will be uploaded
// in parallel as defined by Prepare. Otherwise, parts will be uploaded serially. Run sets b.err on error. // in parallel as defined by Prepare. Otherwise, parts will be uploaded serially. Run sets b.err on error.
func (b *blobUpload) Run(ctx context.Context, opts *RegistryOptions) { func (b *blobUpload) Run(ctx context.Context, opts *auth.RegistryOptions) {
defer blobUploadManager.Delete(b.Digest) defer blobUploadManager.Delete(b.Digest)
ctx, b.CancelFunc = context.WithCancel(ctx) ctx, b.CancelFunc = context.WithCancel(ctx)
@ -212,7 +213,7 @@ func (b *blobUpload) Run(ctx context.Context, opts *RegistryOptions) {
b.done = true b.done = true
} }
func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *url.URL, part *blobUploadPart, opts *RegistryOptions) error { func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *url.URL, part *blobUploadPart, opts *auth.RegistryOptions) error {
headers := make(http.Header) headers := make(http.Header)
headers.Set("Content-Type", "application/octet-stream") headers.Set("Content-Type", "application/octet-stream")
headers.Set("Content-Length", fmt.Sprintf("%d", part.Size)) headers.Set("Content-Length", fmt.Sprintf("%d", part.Size))
@ -227,7 +228,7 @@ func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *
md5sum := md5.New() md5sum := md5.New()
w := &progressWriter{blobUpload: b} w := &progressWriter{blobUpload: b}
resp, err := makeRequest(ctx, method, requestURL, headers, io.TeeReader(sr, io.MultiWriter(w, md5sum)), opts) resp, err := auth.MakeRequest(ctx, method, requestURL, headers, io.TeeReader(sr, io.MultiWriter(w, md5sum)), opts)
if err != nil { if err != nil {
w.Rollback() w.Rollback()
return err return err
@ -277,9 +278,9 @@ func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *
case resp.StatusCode == http.StatusUnauthorized: case resp.StatusCode == http.StatusUnauthorized:
w.Rollback() w.Rollback()
auth := resp.Header.Get("www-authenticate") authenticate := resp.Header.Get("www-authenticate")
authRedir := ParseAuthRedirectString(auth) authRedir := ParseAuthRedirectString(authenticate)
token, err := getAuthToken(ctx, authRedir) token, err := auth.GetAuthToken(ctx, authRedir)
if err != nil { if err != nil {
return err return err
} }
@ -364,7 +365,7 @@ func (p *progressWriter) Rollback() {
p.written = 0 p.written = 0
} }
func uploadBlob(ctx context.Context, mp ModelPath, layer *Layer, opts *RegistryOptions, fn func(api.ProgressResponse)) error { func uploadBlob(ctx context.Context, mp ModelPath, layer *Layer, opts *auth.RegistryOptions, fn func(api.ProgressResponse)) error {
requestURL := mp.BaseURL() requestURL := mp.BaseURL()
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs", layer.Digest) requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs", layer.Digest)