diff --git a/cmd/git-sync/internal/sha256convert/sha256convert.go b/cmd/git-sync/internal/sha256convert/sha256convert.go index bc280104..1b25943a 100644 --- a/cmd/git-sync/internal/sha256convert/sha256convert.go +++ b/cmd/git-sync/internal/sha256convert/sha256convert.go @@ -945,7 +945,7 @@ func openSource(ctx context.Context, req Request, planCfg planner.PlanConfig) (g if ep.Scheme != "http" && ep.Scheme != "https" { return nil, nil, nil, fmt.Errorf("convert-sha256 currently supports HTTP/HTTPS sources only; got %q", ep.Scheme) } - authMethod, err := auth.Resolve(auth.Endpoint{ + authMethod, err := auth.Resolve(ctx, auth.Endpoint{ Username: req.SourceAuth.Username, Token: req.SourceAuth.Token, BearerToken: req.SourceAuth.BearerToken, diff --git a/internal/auth/auth.go b/internal/auth/auth.go index dc78ed1c..713695c8 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -31,14 +31,14 @@ type Endpoint struct { // Resolve resolves the auth method for the given endpoint configuration. // Order: explicit flags → Entire DB token → anonymous (with the git credential // helper deferred until the server returns 401, matching git's own behaviour). -func Resolve(raw Endpoint, ep *url.URL) (Method, error) { +func Resolve(ctx context.Context, raw Endpoint, ep *url.URL) (Method, error) { if auth := explicitAuth(raw); auth != nil { return auth, nil } if !isHTTPEndpoint(ep) { return nil, nil //nolint:nilnil // nil signals no auth method found at this stage } - if username, password, ok, err := LookupEntireDBCredential(raw, ep); err != nil { + if username, password, ok, err := LookupEntireDBCredential(ctx, raw, ep); err != nil { return nil, err // issue #7: surface refresh failure explicitly } else if ok { return &transporthttp.BasicAuth{Username: username, Password: password}, nil diff --git a/internal/auth/auth_test.go b/internal/auth/auth_test.go index 0ddc9d6d..8ed74698 100644 --- a/internal/auth/auth_test.go +++ b/internal/auth/auth_test.go @@ -320,7 +320,7 @@ func TestResolve(t *testing.T) { // doesn't find anything. t.Setenv("ENTIRE_CONFIG_DIR", t.TempDir()) - got, err := Resolve(tt.raw, tt.ep) + got, err := Resolve(context.Background(), tt.raw, tt.ep) if tt.wantErr { if err == nil { t.Fatal("expected error, got nil") @@ -859,6 +859,36 @@ func TestGetTokenWithRefresh(t *testing.T) { }) } +// A cancelled caller context must abort the token refresh, proving the +// context is threaded down to the HTTP request rather than dropped for a +// background one. +func TestGetTokenWithRefreshHonorsContext(t *testing.T) { + dir := t.TempDir() + t.Setenv("ENTIRE_TOKEN_STORE", "file") + t.Setenv("ENTIRE_TOKEN_STORE_PATH", filepath.Join(dir, "tokens.json")) + t.Setenv("ENTIRE_CONFIG_DIR", t.TempDir()) + + // Expired access token plus a refresh token, so refresh is attempted. + pastExpiry := time.Now().Add(-1 * time.Hour).Unix() + if err := WriteStoredToken(credentialService("example.com"), "carol", fmt.Sprintf("stale|%d", pastExpiry)); err != nil { + t.Fatalf("WriteStoredToken: %v", err) + } + if err := WriteStoredToken(credentialService("example.com")+":refresh", "carol", "refresh-tok"); err != nil { + t.Fatalf("WriteStoredToken refresh: %v", err) + } + + ctx, cancel := context.WithCancel(context.Background()) + cancel() // cancel before the refresh HTTP request runs + + _, err := getTokenWithRefresh(ctx, "example.com", "carol", "https://example.invalid", false) + if err == nil { + t.Fatal("expected an error when the context is cancelled") + } + if !errors.Is(err, context.Canceled) { + t.Fatalf("error should wrap context.Canceled, got: %v", err) + } +} + func TestReadWriteStoredTokenFileStore(t *testing.T) { dir := t.TempDir() tokenPath := filepath.Join(dir, "tokens.json") @@ -927,7 +957,7 @@ func TestLookupEntireDBTokenNotConfigured(t *testing.T) { configDir := t.TempDir() t.Setenv("ENTIRE_CONFIG_DIR", configDir) - got, err := lookupEntireDBToken("example.com", "https://example.com", false) + got, err := lookupEntireDBToken(context.Background(), "example.com", "https://example.com", false) if err != nil { t.Fatalf("expected nil error, got %v", err) } diff --git a/internal/auth/entiredb.go b/internal/auth/entiredb.go index 72f73f25..907aaf28 100644 --- a/internal/auth/entiredb.go +++ b/internal/auth/entiredb.go @@ -38,12 +38,12 @@ type oauthTokenResponse struct { // Returns (username, password, true, nil) on success, ("", "", false, nil) when // no credential is configured, or ("", "", false, err) when a credential exists // but refresh failed (issue #7). -func LookupEntireDBCredential(raw Endpoint, ep *url.URL) (string, string, bool, error) { +func LookupEntireDBCredential(ctx context.Context, raw Endpoint, ep *url.URL) (string, string, bool, error) { if ep == nil || ep.Host == "" { return "", "", false, nil } credHost := endpointCredentialHost(ep) - token, err := lookupEntireDBToken(credHost, endpointBaseURL(ep), raw.SkipTLSVerify) + token, err := lookupEntireDBToken(ctx, credHost, endpointBaseURL(ep), raw.SkipTLSVerify) if err != nil { return "", "", false, err } @@ -76,7 +76,7 @@ func endpointCredentialHost(ep *url.URL) string { return ep.Host // includes port if present in url.URL } -func lookupEntireDBToken(host, baseURL string, skipTLS bool) (string, error) { +func lookupEntireDBToken(ctx context.Context, host, baseURL string, skipTLS bool) (string, error) { configDir := os.Getenv("ENTIRE_CONFIG_DIR") if configDir == "" { home, err := os.UserHomeDir() @@ -90,7 +90,7 @@ func lookupEntireDBToken(host, baseURL string, skipTLS bool) (string, error) { if !ok || username == "" { return "", nil } - return getTokenWithRefresh(context.Background(), host, username, baseURL, skipTLS) + return getTokenWithRefresh(ctx, host, username, baseURL, skipTLS) } func loadEntireDBActiveUser(host, configDir string) (string, bool) { @@ -109,9 +109,10 @@ func loadEntireDBActiveUser(host, configDir string) (string, bool) { return info.ActiveUser, true } -// getTokenWithRefresh retrieves a token, refreshing it if expired. -// On refresh failure, returns the stale token with a nil error rather than -// propagating the refresh error silently (issue #7). +// getTokenWithRefresh retrieves a token, refreshing it if expired or expiring. +// On refresh failure it returns the error rather than silently reusing the +// stale token, so the caller surfaces the failure instead of authenticating +// with a known-bad credential (issue #7). func getTokenWithRefresh(ctx context.Context, host, username, baseURL string, skipTLS bool) (string, error) { encoded, err := ReadStoredToken(credentialService(host), username) if err != nil { @@ -180,6 +181,9 @@ func refreshAccessToken(ctx context.Context, host, username, baseURL string, ski client := &http.Client{ Timeout: 30 * time.Second, Transport: &http.Transport{ + // Honor HTTP(S)_PROXY/NO_PROXY like the default transport; a bare + // &http.Transport{} leaves Proxy nil and bypasses the proxy. + Proxy: http.ProxyFromEnvironment, TLSClientConfig: &tls.Config{InsecureSkipVerify: skipTLS}, //nolint:gosec // InsecureSkipVerify is controlled by user flag }, } diff --git a/internal/syncer/auth_test.go b/internal/syncer/auth_test.go index d22a51ac..c0b2f8c8 100644 --- a/internal/syncer/auth_test.go +++ b/internal/syncer/auth_test.go @@ -31,7 +31,7 @@ func TestResolveAuthMethodPrefersExplicitToken(t *testing.T) { return nil, nil } - resolved, err := auth.Resolve(auth.Endpoint{ + resolved, err := auth.Resolve(context.Background(), auth.Endpoint{ Username: "git", Token: "explicit-token", }, ep) @@ -50,7 +50,7 @@ func TestResolveAuthMethodPrefersExplicitToken(t *testing.T) { func TestNewHTTPConnSkipTLSVerify(t *testing.T) { stats := newStats(false) - conn, err := newConn(Endpoint{ + conn, err := newConn(context.Background(), Endpoint{ URL: "https://example.com/repo.git", SkipTLSVerify: true, }, "source", stats, nil) @@ -79,7 +79,7 @@ func TestNewHTTPConnUsesProvidedHTTPClient(t *testing.T) { baseTransport := http.DefaultTransport baseClient := &http.Client{Transport: baseTransport} - conn, err := newConn(Endpoint{URL: "https://example.com/repo.git"}, "source", stats, baseClient) + conn, err := newConn(context.Background(), Endpoint{URL: "https://example.com/repo.git"}, "source", stats, baseClient) if err != nil { t.Fatalf("new conn: %v", err) } @@ -125,7 +125,7 @@ func TestResolveAuthMethodUsesEntireDBStoredToken(t *testing.T) { return nil, nil } - resolved, err := auth.Resolve(auth.Endpoint{}, ep) + resolved, err := auth.Resolve(context.Background(), auth.Endpoint{}, ep) if err != nil { t.Fatalf("resolve auth: %v", err) } @@ -177,7 +177,7 @@ func TestResolveAuthMethodRefreshesExpiredEntireDBToken(t *testing.T) { t.Fatalf("write refresh token: %v", err) } - resolved, err := auth.Resolve(auth.Endpoint{SkipTLSVerify: true}, ep) + resolved, err := auth.Resolve(context.Background(), auth.Endpoint{SkipTLSVerify: true}, ep) if err != nil { t.Fatalf("resolve auth: %v", err) } diff --git a/internal/syncer/git_http_backend_test.go b/internal/syncer/git_http_backend_test.go index cf15e6aa..4bf4de14 100644 --- a/internal/syncer/git_http_backend_test.go +++ b/internal/syncer/git_http_backend_test.go @@ -515,7 +515,7 @@ func TestBootstrap_GitHTTPBackendBatchedBranchResume(t *testing.T) { } stats := newStats(false) - sourceConn, err := newConn(cfg.Source, "source", stats, nil) + sourceConn, err := newConn(context.Background(), cfg.Source, "source", stats, nil) if err != nil { t.Fatalf("create source transport: %v", err) } @@ -612,7 +612,7 @@ func TestBootstrap_GitHTTPBackendBatchedPlanningTracksBatchLimit(t *testing.T) { } stats := newStats(false) - sourceConn, err := newConn(cfg.Source, "source", stats, nil) + sourceConn, err := newConn(context.Background(), cfg.Source, "source", stats, nil) if err != nil { t.Fatalf("create source transport: %v", err) } diff --git a/internal/syncer/syncer.go b/internal/syncer/syncer.go index 88ae768d..a1fc6aea 100644 --- a/internal/syncer/syncer.go +++ b/internal/syncer/syncer.go @@ -349,7 +349,7 @@ func measurementLine(m Measurement) []string { // --- Session setup --- -func newConn(raw Endpoint, label string, stats *statsCollector, httpClient *http.Client) (gitproto.Conn, error) { +func newConn(ctx context.Context, raw Endpoint, label string, stats *statsCollector, httpClient *http.Client) (gitproto.Conn, error) { ep, err := transport.ParseURL(raw.URL) if err != nil { return nil, fmt.Errorf("parse endpoint: %w", err) @@ -369,7 +369,7 @@ func newConn(raw Endpoint, label string, stats *statsCollector, httpClient *http BearerToken: raw.BearerToken, SkipTLSVerify: raw.SkipTLSVerify, } - authMethod, err := auth.Resolve(authEp, ep) + authMethod, err := auth.Resolve(ctx, authEp, ep) if err != nil { return nil, fmt.Errorf("resolve auth: %w", err) } @@ -679,7 +679,7 @@ func newSession(ctx context.Context, cfg Config, needTarget bool) (*syncSession, })) } - s.sourceConn, err = newConn(cfg.Source, "source", s.stats, cfg.HTTPClient) + s.sourceConn, err = newConn(ctx, cfg.Source, "source", s.stats, cfg.HTTPClient) if err != nil { return nil, fmt.Errorf("create source transport: %w", err) } @@ -696,7 +696,7 @@ func newSession(ctx context.Context, cfg Config, needTarget bool) (*syncSession, s.sourceRefMap = gitproto.RefHashMap(sourceRefs) if needTarget { - targetConn, err := newConn(cfg.Target, "target", s.stats, cfg.HTTPClient) + targetConn, err := newConn(ctx, cfg.Target, "target", s.stats, cfg.HTTPClient) if err != nil { return nil, fmt.Errorf("create target transport: %w", err) } diff --git a/internal/syncer/syncer_test.go b/internal/syncer/syncer_test.go index af162258..148b92a4 100644 --- a/internal/syncer/syncer_test.go +++ b/internal/syncer/syncer_test.go @@ -133,7 +133,7 @@ func TestFinalizeCountsTalliesPushedAndDeleted(t *testing.T) { func TestGitHubOwnerRepo(t *testing.T) { stats := newStats(false) - conn, err := newConn(Endpoint{URL: "https://github.com/torvalds/linux.git"}, "source", stats, nil) + conn, err := newConn(context.Background(), Endpoint{URL: "https://github.com/torvalds/linux.git"}, "source", stats, nil) if err != nil { t.Fatalf("new conn: %v", err) } @@ -148,7 +148,7 @@ func TestGitHubOwnerRepo(t *testing.T) { func TestGitHubOwnerRepoRejectsNonGitHubSource(t *testing.T) { stats := newStats(false) - conn, err := newConn(Endpoint{URL: "https://gitlab.com/group/project.git"}, "source", stats, nil) + conn, err := newConn(context.Background(), Endpoint{URL: "https://gitlab.com/group/project.git"}, "source", stats, nil) if err != nil { t.Fatalf("new conn: %v", err) } @@ -225,7 +225,7 @@ func TestProbeWithoutTargetIgnoresEndpointEqualityCheck(t *testing.T) { func TestNewHTTPConn_PropagatesFollowInfoRefsRedirect(t *testing.T) { stats := newStats(false) - off, err := newConn(Endpoint{URL: "https://node.example/repo.git"}, "target", stats, nil) + off, err := newConn(context.Background(), Endpoint{URL: "https://node.example/repo.git"}, "target", stats, nil) if err != nil { t.Fatalf("new conn (off): %v", err) } @@ -237,7 +237,7 @@ func TestNewHTTPConn_PropagatesFollowInfoRefsRedirect(t *testing.T) { t.Error("FollowInfoRefsRedirect should default to false") } - on, err := newConn(Endpoint{URL: "https://node.example/repo.git", FollowInfoRefsRedirect: true}, "target", stats, nil) + on, err := newConn(context.Background(), Endpoint{URL: "https://node.example/repo.git", FollowInfoRefsRedirect: true}, "target", stats, nil) if err != nil { t.Fatalf("new conn (on): %v", err) } @@ -269,7 +269,7 @@ func TestNewConnBuildsSSHTransport(t *testing.T) { } for _, raw := range tests { t.Run(raw, func(t *testing.T) { - conn, err := newConn(Endpoint{URL: raw}, "source", stats, nil) + conn, err := newConn(context.Background(), Endpoint{URL: raw}, "source", stats, nil) if err != nil { t.Fatalf("new conn: %v", err) }