From be989d89d1345384c21a2cb4eee0ccdc16f1df5b Mon Sep 17 00:00:00 2001 From: Patrick Devine Date: Thu, 10 Aug 2023 11:34:25 -0700 Subject: [PATCH] Token auth (#314) --- api/types.go | 4 ++ server/auth.go | 164 +++++++++++++++++++++++++++++++++++++++++++++++ server/images.go | 71 ++++++++++++++++++-- 3 files changed, 233 insertions(+), 6 deletions(-) create mode 100644 server/auth.go diff --git a/api/types.go b/api/types.go index 825db36e..0441a799 100644 --- a/api/types.go +++ b/api/types.go @@ -98,6 +98,10 @@ type ListResponseModel struct { Size int `json:"size"` } +type TokenResponse struct { + Token string `json:"token"` +} + type GenerateResponse struct { Model string `json:"model"` CreatedAt time.Time `json:"created_at"` diff --git a/server/auth.go b/server/auth.go new file mode 100644 index 00000000..d7803a2b --- /dev/null +++ b/server/auth.go @@ -0,0 +1,164 @@ +package server + +import ( + "bytes" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "io/ioutil" + "log" + "net/http" + "os" + "path" + "strings" + "time" + + "golang.org/x/crypto/ssh" + + "github.com/jmorganca/ollama/api" +) + +type AuthRedirect struct { + Realm string + Service string + Scope string +} + +type SignatureData struct { + Method string + Path string + Data []byte +} + +func generateNonce(length int) (string, error) { + nonce := make([]byte, length) + _, err := rand.Read(nonce) + if err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(nonce), nil +} + +func (r AuthRedirect) URL() (string, error) { + nonce, err := generateNonce(16) + if err != nil { + return "", err + } + return fmt.Sprintf("%s?service=%s&scope=%s&ts=%d&nonce=%s", r.Realm, r.Service, r.Scope, time.Now().Unix(), nonce), nil +} + +func getAuthToken(redirData AuthRedirect, regOpts *RegistryOptions) (string, error) { + url, err := redirData.URL() + if err != nil { + return "", err + } + + home, err := os.UserHomeDir() + if err != nil { + return "", err + } + + keyPath := path.Join(home, ".ollama/id_ed25519") + + rawKey, err := ioutil.ReadFile(keyPath) + if err != nil { + log.Printf("Failed to load private key: %v", err) + return "", err + } + + s := SignatureData{ + Method: "GET", + Path: url, + Data: nil, + } + + if !strings.HasPrefix(s.Path, "http") { + if regOpts.Insecure { + s.Path = "http://" + url + } else { + s.Path = "https://" + url + } + } + + sig, err := s.Sign(rawKey) + if err != nil { + return "", err + } + + headers := map[string]string{ + "Authorization": sig, + } + + resp, err := makeRequest("GET", url, headers, nil, regOpts) + if err != nil { + log.Printf("couldn't get token: %q", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return "", fmt.Errorf("on pull registry responded with code %d: %s", resp.StatusCode, body) + } + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return "", err + } + + var tok api.TokenResponse + if err := json.Unmarshal(respBody, &tok); err != nil { + return "", err + } + + return tok.Token, nil +} + +// Bytes returns a byte slice of the data to sign for the request +func (s SignatureData) Bytes() []byte { + // We first derive the content hash of the request body using: + // base64(hex(sha256(request body))) + + hash := sha256.Sum256(s.Data) + hashHex := make([]byte, hex.EncodedLen(len(hash))) + hex.Encode(hashHex, hash[:]) + contentHash := base64.StdEncoding.EncodeToString(hashHex) + + // We then put the entire request together in a serialize string using: + // ",," + // e.g. "GET,http://localhost,OTdkZjM1O..." + + return []byte(strings.Join([]string{s.Method, s.Path, contentHash}, ",")) +} + +// SignData takes a SignatureData object and signs it with a raw private key +func (s SignatureData) Sign(rawKey []byte) (string, error) { + privateKey, err := ssh.ParseRawPrivateKey(rawKey) + if err != nil { + return "", err + } + + signer, err := ssh.NewSignerFromKey(privateKey) + if err != nil { + return "", err + } + + // get the pubkey, but remove the type + pubKey := ssh.MarshalAuthorizedKey(signer.PublicKey()) + parts := bytes.Split(pubKey, []byte(" ")) + if len(parts) < 2 { + return "", fmt.Errorf("malformed public key") + } + + signedData, err := signer.Sign(nil, s.Bytes()) + if err != nil { + return "", err + } + + // signature is : + sig := fmt.Sprintf("%s:%s", bytes.TrimSpace(parts[1]), base64.StdEncoding.EncodeToString(signedData.Blob)) + return sig, nil +} diff --git a/server/images.go b/server/images.go index 2ec24854..8a01053e 100644 --- a/server/images.go +++ b/server/images.go @@ -28,6 +28,7 @@ type RegistryOptions struct { Insecure bool Username string Password string + Token string } type Model struct { @@ -1129,18 +1130,30 @@ func makeRequest(method, url string, headers map[string]string, body io.Reader, } } - req, err := http.NewRequest(method, url, body) + // make a copy of the body in case we need to try the call to makeRequest again + var buf bytes.Buffer + if body != nil { + _, err := io.Copy(&buf, body) + if err != nil { + return nil, err + } + } + + bodyCopy := bytes.NewReader(buf.Bytes()) + + req, err := http.NewRequest(method, url, bodyCopy) if err != nil { return nil, err } - for k, v := range headers { - req.Header.Set(k, v) + if regOpts.Token != "" { + req.Header.Set("Authorization", "Bearer "+regOpts.Token) + } else if regOpts.Username != "" && regOpts.Password != "" { + req.SetBasicAuth(regOpts.Username, regOpts.Password) } - // TODO: better auth - if regOpts.Username != "" && regOpts.Password != "" { - req.SetBasicAuth(regOpts.Username, regOpts.Password) + for k, v := range headers { + req.Header.Set(k, v) } client := &http.Client{ @@ -1157,9 +1170,55 @@ func makeRequest(method, url string, headers map[string]string, body io.Reader, return nil, err } + // if the request is unauthenticated, try to authenticate and make the request again + if resp.StatusCode == http.StatusUnauthorized { + auth := resp.Header.Get("Www-Authenticate") + authRedir := ParseAuthRedirectString(string(auth)) + token, err := getAuthToken(authRedir, regOpts) + if err != nil { + return nil, err + } + regOpts.Token = token + bodyCopy = bytes.NewReader(buf.Bytes()) + return makeRequest(method, url, headers, bodyCopy, regOpts) + } + return resp, nil } +func getValue(header, key string) string { + startIdx := strings.Index(header, key+"=") + if startIdx == -1 { + return "" + } + + // Move the index to the starting quote after the key. + startIdx += len(key) + 2 + endIdx := startIdx + + for endIdx < len(header) { + if header[endIdx] == '"' { + if endIdx+1 < len(header) && header[endIdx+1] != ',' { // If the next character isn't a comma, continue + endIdx++ + continue + } + break + } + endIdx++ + } + return header[startIdx:endIdx] +} + +func ParseAuthRedirectString(authStr string) AuthRedirect { + authStr = strings.TrimPrefix(authStr, "Bearer ") + + return AuthRedirect{ + Realm: getValue(authStr, "realm"), + Service: getValue(authStr, "service"), + Scope: getValue(authStr, "scope"), + } +} + var errDigestMismatch = fmt.Errorf("digest mismatch, file must be downloaded again") func verifyBlob(digest string) error {