diff --git a/server/upload.go b/server/upload.go index a977581f..53d8de37 100644 --- a/server/upload.go +++ b/server/upload.go @@ -57,6 +57,12 @@ func uploadBlobChunked(ctx context.Context, requestURL *url.URL, layer *Layer, r // 95MB chunk size chunkSize := 95 * 1024 * 1024 + pw := ProgressWriter{ + status: fmt.Sprintf("uploading %s", layer.Digest), + digest: layer.Digest, + total: layer.Size, + fn: fn, + } for offset := int64(0); offset < int64(layer.Size); { chunk := int64(layer.Size) - offset @@ -64,88 +70,23 @@ func uploadBlobChunked(ctx context.Context, requestURL *url.URL, layer *Layer, r chunk = int64(chunkSize) } - sectionReader := io.NewSectionReader(f, int64(offset), chunk) - for try := 0; try < MaxRetries; try++ { - ch := make(chan error, 1) - - r, w := io.Pipe() - defer r.Close() - go func() { - defer w.Close() - - for chunked := int64(0); chunked < chunk; { - select { - case err := <-ch: - log.Printf("chunk interrupted: %v", err) - return - default: - n, err := io.CopyN(w, sectionReader, 1024*1024) - if err != nil && !errors.Is(err, io.EOF) { - fn(api.ProgressResponse{ - Status: fmt.Sprintf("error reading chunk: %v", err), - Digest: layer.Digest, - Total: layer.Size, - Completed: int(offset), - }) - - return - } - - chunked += n - fn(api.ProgressResponse{ - Status: fmt.Sprintf("uploading %s", layer.Digest), - Digest: layer.Digest, - Total: layer.Size, - Completed: int(offset) + int(chunked), - }) - } - } - }() - - headers := make(http.Header) - headers.Set("Content-Type", "application/octet-stream") - headers.Set("Content-Length", strconv.Itoa(int(chunk))) - headers.Set("Content-Range", fmt.Sprintf("%d-%d", offset, offset+sectionReader.Size()-1)) - resp, err := makeRequest(ctx, "PATCH", requestURL, headers, r, regOpts) - if err != nil && !errors.Is(err, io.EOF) { - fn(api.ProgressResponse{ - Status: fmt.Sprintf("error uploading chunk: %v", err), - Digest: layer.Digest, - Total: layer.Size, - Completed: int(offset), - }) - - return err - } - defer resp.Body.Close() - - switch { - case resp.StatusCode == http.StatusUnauthorized: - ch <- errors.New("unauthorized") - - auth := resp.Header.Get("www-authenticate") - authRedir := ParseAuthRedirectString(auth) - token, err := getAuthToken(ctx, authRedir) - if err != nil { - return err - } - - regOpts.Token = token - sectionReader = io.NewSectionReader(f, int64(offset), chunk) - continue - case resp.StatusCode >= http.StatusBadRequest: - body, _ := io.ReadAll(resp.Body) - return fmt.Errorf("on upload registry responded with code %d: %s", resp.StatusCode, body) - } - - offset += sectionReader.Size() - requestURL, err = url.Parse(resp.Header.Get("Location")) - if err != nil { - return err - } - - break + resp, err := uploadBlobChunk(ctx, requestURL, f, offset, chunk, regOpts, &pw) + if err != nil { + fn(api.ProgressResponse{ + Status: fmt.Sprintf("error uploading limit: %v", err), + Digest: layer.Digest, + Total: layer.Size, + Completed: int(offset), + }) } + + offset += chunk + location, err := resp.Location() + if err != nil { + return err + } + + requestURL = location } values := requestURL.Query() @@ -170,3 +111,72 @@ func uploadBlobChunked(ctx context.Context, requestURL *url.URL, layer *Layer, r } return nil } + +func uploadBlobChunk(ctx context.Context, requestURL *url.URL, r io.ReaderAt, offset, limit int64, opts *RegistryOptions, pw *ProgressWriter) (*http.Response, error) { + sectionReader := io.NewSectionReader(r, int64(offset), limit) + + headers := make(http.Header) + headers.Set("Content-Type", "application/octet-stream") + headers.Set("Content-Length", strconv.Itoa(int(limit))) + headers.Set("Content-Range", fmt.Sprintf("%d-%d", offset, offset+sectionReader.Size()-1)) + + for try := 0; try < MaxRetries; try++ { + resp, err := makeRequest(ctx, "PATCH", requestURL, headers, io.TeeReader(sectionReader, pw), opts) + if err != nil && !errors.Is(err, io.EOF) { + return nil, err + } + defer resp.Body.Close() + + switch { + case resp.StatusCode == http.StatusUnauthorized: + auth := resp.Header.Get("www-authenticate") + authRedir := ParseAuthRedirectString(auth) + token, err := getAuthToken(ctx, authRedir) + if err != nil { + return nil, err + } + + opts.Token = token + + pw.completed = int(offset) + sectionReader = io.NewSectionReader(r, offset, limit) + continue + case resp.StatusCode >= http.StatusBadRequest: + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("on upload registry responded with code %d: %s", resp.StatusCode, body) + } + + return resp, nil + } + + return nil, fmt.Errorf("max retries exceeded") +} + +type ProgressWriter struct { + status string + digest string + bucket int + completed int + total int + fn func(api.ProgressResponse) +} + +func (pw *ProgressWriter) Write(b []byte) (int, error) { + n := len(b) + pw.bucket += n + pw.completed += n + + // throttle status updates to not spam the client + if pw.bucket >= 1024*1024 || pw.completed >= pw.total { + pw.fn(api.ProgressResponse{ + Status: pw.status, + Digest: pw.digest, + Total: pw.total, + Completed: pw.completed, + }) + + pw.bucket = 0 + } + + return n, nil +}