Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 11 additions & 31 deletions pkg/inference/backends/llamacpp/download.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,8 @@ package llamacpp

import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"os"
"os/exec"
"path/filepath"
Expand Down Expand Up @@ -50,7 +47,7 @@ func SetDesiredServerVersion(version string) {
}

//nolint:unused // Used in platform-specific files (download_darwin.go, download_windows.go)
func (l *llamaCpp) downloadLatestLlamaCpp(ctx context.Context, log logging.Logger, httpClient *http.Client,
func (l *llamaCpp) downloadLatestLlamaCpp(ctx context.Context, log logging.Logger,
llamaCppPath, vendoredServerStoragePath, desiredVersion, desiredVariant string,
) error {
ShouldUpdateServerLock.Lock()
Expand All @@ -63,35 +60,18 @@ func (l *llamaCpp) downloadLatestLlamaCpp(ctx context.Context, log logging.Logge

log.Info("downloadLatestLlamaCpp", "desiredVersion", desiredVersion, "desiredVariant", desiredVariant, "vendoredServerStoragePath", vendoredServerStoragePath, "llamaCppPath", llamaCppPath)
desiredTag := desiredVersion + "-" + desiredVariant
url := fmt.Sprintf("https://hub.docker.com/v2/namespaces/%s/repositories/%s/tags/%s", hubNamespace, hubRepo, desiredTag)
resp, err := httpClient.Get(url)
if err != nil {
return err
}
defer resp.Body.Close()

body, err := io.ReadAll(resp.Body)
// Resolve the desired tag to a digest via the Registry HTTP API v2. This
// honors l.registryMirrors (typically a corporate Artifactory / Nexus /
// Harbor mirror configured for docker.io) and credentials populated by
// `docker login`, so customers behind a private mirror with no direct
// egress to registry-1.docker.io can still resolve and pull the backend
// image. See docker/model-runner#TBD.
tagRef := fmt.Sprintf("registry-1.docker.io/%s/%s:%s", hubNamespace, hubRepo, desiredTag)
latest, err := dockerhub.ResolveDigest(ctx, tagRef, l.registryMirrors)
if err != nil {
return fmt.Errorf("failed to read response body: %w", err)
}

// https://docs.docker.com/reference/api/hub/latest/#tag/repositories/paths/~1v2~1namespaces~1%7Bnamespace%7D~1repositories~1%7Brepository%7D~1tags~1%7Btag%7D/get
var response struct {
Name string `json:"name"`
Digest string `json:"digest"`
}

if unmarshalErr := json.Unmarshal(body, &response); unmarshalErr != nil {
return fmt.Errorf("failed to unmarshal response body: %w", unmarshalErr)
}

var latest string
if response.Name == desiredTag {
latest = response.Digest
}
if latest == "" {
log.Warn("could not find the tag", "tag", desiredTag, "response", body)
return fmt.Errorf("could not find the %s tag", desiredTag)
log.Warn("could not resolve llama.cpp tag", "tag", desiredTag, "mirrors", l.registryMirrors, "error", err)
return fmt.Errorf("could not resolve the %s tag: %w", desiredTag, err)
}

bundledVersionFile := filepath.Join(vendoredServerStoragePath, "com.docker.llama-server.digest")
Expand Down
4 changes: 2 additions & 2 deletions pkg/inference/backends/llamacpp/download_darwin.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@ import (
"github.com/docker/model-runner/pkg/logging"
)

func (l *llamaCpp) ensureLatestLlamaCpp(ctx context.Context, log logging.Logger, httpClient *http.Client,
func (l *llamaCpp) ensureLatestLlamaCpp(ctx context.Context, log logging.Logger, _ *http.Client,
llamaCppPath, vendoredServerStoragePath string,
) error {
desiredVersion := GetDesiredServerVersion()
desiredVariant := "metal"
return l.downloadLatestLlamaCpp(ctx, log, httpClient, llamaCppPath, vendoredServerStoragePath, desiredVersion,
return l.downloadLatestLlamaCpp(ctx, log, llamaCppPath, vendoredServerStoragePath, desiredVersion,
desiredVariant)
}
4 changes: 2 additions & 2 deletions pkg/inference/backends/llamacpp/download_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
"github.com/docker/model-runner/pkg/logging"
)

func (l *llamaCpp) ensureLatestLlamaCpp(ctx context.Context, log logging.Logger, httpClient *http.Client,
func (l *llamaCpp) ensureLatestLlamaCpp(ctx context.Context, log logging.Logger, _ *http.Client,
llamaCppPath, vendoredServerStoragePath string,
) error {
nvGPUInfoBin := filepath.Join(vendoredServerStoragePath, "com.docker.nv-gpu-info.exe")
Expand Down Expand Up @@ -43,6 +43,6 @@ func (l *llamaCpp) ensureLatestLlamaCpp(ctx context.Context, log logging.Logger,
desiredVariant = "opencl"
}
l.status = inference.FormatInstalling(fmt.Sprintf("%s llama.cpp %s", inference.DetailCheckingForUpdates, desiredVariant))
return l.downloadLatestLlamaCpp(ctx, log, httpClient, llamaCppPath, vendoredServerStoragePath, desiredVersion,
return l.downloadLatestLlamaCpp(ctx, log, llamaCppPath, vendoredServerStoragePath, desiredVersion,
desiredVariant)
}
52 changes: 48 additions & 4 deletions pkg/internal/dockerhub/download.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"github.com/containerd/containerd/v2/core/remotes"
"github.com/containerd/containerd/v2/core/remotes/docker"
"github.com/containerd/containerd/v2/plugins/content/local"
"github.com/containerd/errdefs"
"github.com/containerd/platforms"
"github.com/docker/model-runner/pkg/internal/jsonutil"
"github.com/docker/model-runner/pkg/internal/registryutil"
Expand Down Expand Up @@ -47,6 +48,39 @@ func PullPlatform(ctx context.Context, image, destination, requiredOs, requiredA
return archive.Export(ctx, store, output, archive.WithManifest(*desc, image), archive.WithSkipMissing(store))
}

// ResolveDigest resolves the given image reference (e.g. "registry-1.docker.io/docker/foo:tag")
// against the registry (with optional mirrors tried first for Docker Hub references) and
// returns the resolved digest. It does not download any blobs; it issues only the manifest
// HEAD/GET that the registry resolver needs.
//
// Authentication uses the same credentials lookup as PullPlatform (env vars
// DOCKER_HUB_USER/DOCKER_HUB_PASSWORD or ~/.docker/config.json), so a prior
// `docker login <mirror-host>` is honored.
func ResolveDigest(ctx context.Context, ref string, mirrors []string) (string, error) {
resolver := newResolver(mirrors)
desc, err := retry(ctx, 10, 1*time.Second, func() (*v1.Descriptor, error) {
name, d, err := resolver.Resolve(ctx, ref)
if err != nil {
return nil, err
}
slog.Debug("resolved image tag", "ref", ref, "resolved", name, "digest", d.Digest.String())
return &d, nil
})
if err != nil {
return "", fmt.Errorf("resolving image %q: %w", ref, err)
}
return desc.Digest.String(), nil
}

// newResolver builds a containerd docker resolver that authenticates via
// dockerCredentials and tries the given mirrors before the upstream registry.
func newResolver(mirrors []string) remotes.Resolver {
authorizer := docker.NewDockerAuthorizer(docker.WithAuthCreds(dockerCredentials))
return docker.NewResolver(docker.ResolverOptions{
Hosts: registryutil.RegistryHosts(mirrors, authorizer, nil),
})
}

func retry(ctx context.Context, attempts int, sleep time.Duration, f func() (*v1.Descriptor, error)) (*v1.Descriptor, error) {
var err error
var result *v1.Descriptor
Expand All @@ -63,15 +97,25 @@ func retry(ctx context.Context, attempts int, sleep time.Duration, f func() (*v1
if err == nil {
return result, nil
}
if isTerminal(err) {
return nil, err
}
}
return nil, fmt.Errorf("after %d attempts, last error: %w", attempts, err)
}

// isTerminal reports whether err is non-retryable: a missing tag/manifest, an
// authentication failure, or a canceled/expired context. Retrying these only
// wastes time, so the caller should fail fast instead of looping.
func isTerminal(err error) bool {
return errdefs.IsNotFound(err) ||
errdefs.IsUnauthorized(err) ||
errors.Is(err, context.Canceled) ||
errors.Is(err, context.DeadlineExceeded)
}
Comment thread
ilopezluna marked this conversation as resolved.

func fetch(ctx context.Context, store content.Store, ref, requiredOs, requiredArch string, mirrors []string) (*v1.Descriptor, error) {
authorizer := docker.NewDockerAuthorizer(docker.WithAuthCreds(dockerCredentials))
resolver := docker.NewResolver(docker.ResolverOptions{
Hosts: registryutil.RegistryHosts(mirrors, authorizer, nil),
})
resolver := newResolver(mirrors)
Comment thread
ilopezluna marked this conversation as resolved.
Outdated
name, desc, err := resolver.Resolve(ctx, ref)
if err != nil {
return nil, err
Expand Down
139 changes: 139 additions & 0 deletions pkg/internal/dockerhub/download_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
package dockerhub

import (
"context"
"errors"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"sync/atomic"
"testing"
"time"

"github.com/containerd/errdefs"
v1 "github.com/opencontainers/image-spec/specs-go/v1"
)

// registryHandler is a minimal Docker Registry v2 HTTP handler that supports
// the manifest HEAD / GET requests issued by containerd's docker resolver.
type registryHandler struct {
// tag is the tag to recognize; for any other tag the handler returns 404.
tag string
// digest returned in the Docker-Content-Digest header.
digest string
// requests counts how many requests this handler received (for assertions).
requests atomic.Int64
}

func (h *registryHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
h.requests.Add(1)
switch {
case r.URL.Path == "/v2/" || r.URL.Path == "/v2":
// API version probe.
w.Header().Set("Docker-Distribution-API-Version", "registry/2.0")
w.WriteHeader(http.StatusOK)
case strings.HasSuffix(r.URL.Path, "/manifests/"+h.tag):
// Manifest HEAD/GET for the recognized tag.
w.Header().Set("Docker-Content-Digest", h.digest)
w.Header().Set("Content-Type", "application/vnd.oci.image.index.v1+json")
body := []byte(`{"schemaVersion":2,"mediaType":"application/vnd.oci.image.index.v1+json","manifests":[]}`)
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(body)))
if r.Method == http.MethodHead {
w.WriteHeader(http.StatusOK)
return
}
w.WriteHeader(http.StatusOK)
_, _ = w.Write(body)
default:
http.Error(w, "not found", http.StatusNotFound)
}
}

// TestResolveDigest_UsesMirror verifies that when a mirror is configured for
// Docker Hub references, the resolver issues its manifest lookup against the
// mirror rather than registry-1.docker.io. This is the path enterprise
// customers behind an Artifactory / Nexus / Harbor mirror need.
func TestResolveDigest_UsesMirror(t *testing.T) {
const wantDigest = "sha256:48883a67000000000000000000000000000000000000000000000000deadbeef"

mirror := &registryHandler{tag: "latest-cuda", digest: wantDigest}
srv := httptest.NewServer(mirror)
defer srv.Close()

ctx, cancel := context.WithTimeout(t.Context(), 10*time.Second)
defer cancel()

// Reference points at registry-1.docker.io; the mirror should intercept it.
ref := "registry-1.docker.io/docker/docker-model-backend-llamacpp:latest-cuda"
got, err := ResolveDigest(ctx, ref, []string{srv.URL})
if err != nil {
t.Fatalf("ResolveDigest returned error: %v", err)
}
if got != wantDigest {
t.Fatalf("digest mismatch: got %q want %q", got, wantDigest)
}
if mirror.requests.Load() == 0 {
t.Fatalf("expected mirror to be called at least once, got 0 requests")
}
}

// TestResolveDigest_CanceledContext verifies the resolver does not block when
// the context is already canceled. This protects against silent stalls when
// the network path to the upstream/mirror is blackholed (a frequent symptom
// in enterprise networks).
func TestResolveDigest_CanceledContext(t *testing.T) {
ctx, cancel := context.WithCancel(t.Context())
cancel()

// No mirror, no real network call should complete. We bound the test
// with a wall-clock deadline so a regression cannot hang CI. A canceled
// context is classified as terminal, so retry must not loop.
done := make(chan struct{})
var resolveErr error
go func() {
_, resolveErr = ResolveDigest(ctx, "registry-1.docker.io/docker/docker-model-backend-llamacpp:latest-cuda", nil)
close(done)
}()
select {
case <-done:
case <-time.After(5 * time.Second):
t.Fatalf("ResolveDigest did not return on canceled context within 5s")
}
if resolveErr == nil {
t.Fatalf("expected error on canceled context, got nil")
}
}

// TestRetry_FailsFastOnTerminalError verifies retry does not loop on a
// non-retryable error (e.g. a missing tag / 404). Before this, every error was
// retried 10 times with 1s sleeps (~9s), blocking the install/startup path.
func TestRetry_FailsFastOnTerminalError(t *testing.T) {
var calls int
_, err := retry(t.Context(), 10, time.Second, func() (*v1.Descriptor, error) {
calls++
return nil, errdefs.ErrNotFound
})
if err == nil {
t.Fatalf("expected error on terminal failure, got nil")
}
if calls != 1 {
t.Fatalf("expected exactly 1 attempt on a terminal error, got %d", calls)
}
}

// TestRetry_RetriesTransientError verifies retry still loops the full budget on
// an unclassified (transient) error, preserving the original behavior.
func TestRetry_RetriesTransientError(t *testing.T) {
var calls int
_, err := retry(t.Context(), 3, time.Millisecond, func() (*v1.Descriptor, error) {
calls++
return nil, errors.New("transient network blip")
})
if err == nil {
t.Fatalf("expected error after exhausting attempts, got nil")
}
if calls != 3 {
t.Fatalf("expected 3 attempts on a transient error, got %d", calls)
}
}
12 changes: 12 additions & 0 deletions pkg/internal/dockerhub/testmain_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package dockerhub

import (
"testing"

"go.uber.org/goleak"
)

// TestMain runs goleak after the test suite to detect goroutine leaks.
func TestMain(m *testing.M) {
goleak.VerifyTestMain(m)
}