Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmd/git-sync/internal/sha256convert/sha256convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -945,7 +945,7 @@ func openSource(ctx context.Context, req Request, planCfg planner.PlanConfig) (g
if ep.Scheme != "http" && ep.Scheme != "https" {
return nil, nil, nil, fmt.Errorf("convert-sha256 currently supports HTTP/HTTPS sources only; got %q", ep.Scheme)
}
authMethod, err := auth.Resolve(auth.Endpoint{
authMethod, err := auth.Resolve(ctx, auth.Endpoint{
Username: req.SourceAuth.Username,
Token: req.SourceAuth.Token,
BearerToken: req.SourceAuth.BearerToken,
Expand Down
4 changes: 2 additions & 2 deletions internal/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,14 @@ type Endpoint struct {
// Resolve resolves the auth method for the given endpoint configuration.
// Order: explicit flags → Entire DB token → anonymous (with the git credential
// helper deferred until the server returns 401, matching git's own behaviour).
func Resolve(raw Endpoint, ep *url.URL) (Method, error) {
func Resolve(ctx context.Context, raw Endpoint, ep *url.URL) (Method, error) {
if auth := explicitAuth(raw); auth != nil {
return auth, nil
}
if !isHTTPEndpoint(ep) {
return nil, nil //nolint:nilnil // nil signals no auth method found at this stage
}
if username, password, ok, err := LookupEntireDBCredential(raw, ep); err != nil {
if username, password, ok, err := LookupEntireDBCredential(ctx, raw, ep); err != nil {
return nil, err // issue #7: surface refresh failure explicitly
} else if ok {
return &transporthttp.BasicAuth{Username: username, Password: password}, nil
Expand Down
34 changes: 32 additions & 2 deletions internal/auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ func TestResolve(t *testing.T) {
// doesn't find anything.
t.Setenv("ENTIRE_CONFIG_DIR", t.TempDir())

got, err := Resolve(tt.raw, tt.ep)
got, err := Resolve(context.Background(), tt.raw, tt.ep)
if tt.wantErr {
if err == nil {
t.Fatal("expected error, got nil")
Expand Down Expand Up @@ -859,6 +859,36 @@ func TestGetTokenWithRefresh(t *testing.T) {
})
}

// A cancelled caller context must abort the token refresh, proving the
// context is threaded down to the HTTP request rather than dropped for a
// background one.
func TestGetTokenWithRefreshHonorsContext(t *testing.T) {
dir := t.TempDir()
t.Setenv("ENTIRE_TOKEN_STORE", "file")
t.Setenv("ENTIRE_TOKEN_STORE_PATH", filepath.Join(dir, "tokens.json"))
t.Setenv("ENTIRE_CONFIG_DIR", t.TempDir())

// Expired access token plus a refresh token, so refresh is attempted.
pastExpiry := time.Now().Add(-1 * time.Hour).Unix()
if err := WriteStoredToken(credentialService("example.com"), "carol", fmt.Sprintf("stale|%d", pastExpiry)); err != nil {
t.Fatalf("WriteStoredToken: %v", err)
}
if err := WriteStoredToken(credentialService("example.com")+":refresh", "carol", "refresh-tok"); err != nil {
t.Fatalf("WriteStoredToken refresh: %v", err)
}

ctx, cancel := context.WithCancel(context.Background())
cancel() // cancel before the refresh HTTP request runs

_, err := getTokenWithRefresh(ctx, "example.com", "carol", "https://example.invalid", false)
if err == nil {
t.Fatal("expected an error when the context is cancelled")
}
if !errors.Is(err, context.Canceled) {
t.Fatalf("error should wrap context.Canceled, got: %v", err)
}
}

func TestReadWriteStoredTokenFileStore(t *testing.T) {
dir := t.TempDir()
tokenPath := filepath.Join(dir, "tokens.json")
Expand Down Expand Up @@ -927,7 +957,7 @@ func TestLookupEntireDBTokenNotConfigured(t *testing.T) {
configDir := t.TempDir()
t.Setenv("ENTIRE_CONFIG_DIR", configDir)

got, err := lookupEntireDBToken("example.com", "https://example.com", false)
got, err := lookupEntireDBToken(context.Background(), "example.com", "https://example.com", false)
if err != nil {
t.Fatalf("expected nil error, got %v", err)
}
Expand Down
18 changes: 11 additions & 7 deletions internal/auth/entiredb.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,12 @@ type oauthTokenResponse struct {
// Returns (username, password, true, nil) on success, ("", "", false, nil) when
// no credential is configured, or ("", "", false, err) when a credential exists
// but refresh failed (issue #7).
func LookupEntireDBCredential(raw Endpoint, ep *url.URL) (string, string, bool, error) {
func LookupEntireDBCredential(ctx context.Context, raw Endpoint, ep *url.URL) (string, string, bool, error) {
if ep == nil || ep.Host == "" {
return "", "", false, nil
}
credHost := endpointCredentialHost(ep)
token, err := lookupEntireDBToken(credHost, endpointBaseURL(ep), raw.SkipTLSVerify)
token, err := lookupEntireDBToken(ctx, credHost, endpointBaseURL(ep), raw.SkipTLSVerify)
if err != nil {
return "", "", false, err
}
Expand Down Expand Up @@ -76,7 +76,7 @@ func endpointCredentialHost(ep *url.URL) string {
return ep.Host // includes port if present in url.URL
}

func lookupEntireDBToken(host, baseURL string, skipTLS bool) (string, error) {
func lookupEntireDBToken(ctx context.Context, host, baseURL string, skipTLS bool) (string, error) {
configDir := os.Getenv("ENTIRE_CONFIG_DIR")
if configDir == "" {
home, err := os.UserHomeDir()
Expand All @@ -90,7 +90,7 @@ func lookupEntireDBToken(host, baseURL string, skipTLS bool) (string, error) {
if !ok || username == "" {
return "", nil
}
return getTokenWithRefresh(context.Background(), host, username, baseURL, skipTLS)
return getTokenWithRefresh(ctx, host, username, baseURL, skipTLS)
}

func loadEntireDBActiveUser(host, configDir string) (string, bool) {
Expand All @@ -109,9 +109,10 @@ func loadEntireDBActiveUser(host, configDir string) (string, bool) {
return info.ActiveUser, true
}

// getTokenWithRefresh retrieves a token, refreshing it if expired.
// On refresh failure, returns the stale token with a nil error rather than
// propagating the refresh error silently (issue #7).
// getTokenWithRefresh retrieves a token, refreshing it if expired or expiring.
// On refresh failure it returns the error rather than silently reusing the
// stale token, so the caller surfaces the failure instead of authenticating
// with a known-bad credential (issue #7).
func getTokenWithRefresh(ctx context.Context, host, username, baseURL string, skipTLS bool) (string, error) {
encoded, err := ReadStoredToken(credentialService(host), username)
if err != nil {
Expand Down Expand Up @@ -180,6 +181,9 @@ func refreshAccessToken(ctx context.Context, host, username, baseURL string, ski
client := &http.Client{
Timeout: 30 * time.Second,
Transport: &http.Transport{
// Honor HTTP(S)_PROXY/NO_PROXY like the default transport; a bare
// &http.Transport{} leaves Proxy nil and bypasses the proxy.
Proxy: http.ProxyFromEnvironment,
TLSClientConfig: &tls.Config{InsecureSkipVerify: skipTLS}, //nolint:gosec // InsecureSkipVerify is controlled by user flag
},
}
Expand Down
10 changes: 5 additions & 5 deletions internal/syncer/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func TestResolveAuthMethodPrefersExplicitToken(t *testing.T) {
return nil, nil
}

resolved, err := auth.Resolve(auth.Endpoint{
resolved, err := auth.Resolve(context.Background(), auth.Endpoint{
Username: "git",
Token: "explicit-token",
}, ep)
Expand All @@ -50,7 +50,7 @@ func TestResolveAuthMethodPrefersExplicitToken(t *testing.T) {

func TestNewHTTPConnSkipTLSVerify(t *testing.T) {
stats := newStats(false)
conn, err := newConn(Endpoint{
conn, err := newConn(context.Background(), Endpoint{
URL: "https://example.com/repo.git",
SkipTLSVerify: true,
}, "source", stats, nil)
Expand Down Expand Up @@ -79,7 +79,7 @@ func TestNewHTTPConnUsesProvidedHTTPClient(t *testing.T) {
baseTransport := http.DefaultTransport
baseClient := &http.Client{Transport: baseTransport}

conn, err := newConn(Endpoint{URL: "https://example.com/repo.git"}, "source", stats, baseClient)
conn, err := newConn(context.Background(), Endpoint{URL: "https://example.com/repo.git"}, "source", stats, baseClient)
if err != nil {
t.Fatalf("new conn: %v", err)
}
Expand Down Expand Up @@ -125,7 +125,7 @@ func TestResolveAuthMethodUsesEntireDBStoredToken(t *testing.T) {
return nil, nil
}

resolved, err := auth.Resolve(auth.Endpoint{}, ep)
resolved, err := auth.Resolve(context.Background(), auth.Endpoint{}, ep)
if err != nil {
t.Fatalf("resolve auth: %v", err)
}
Expand Down Expand Up @@ -177,7 +177,7 @@ func TestResolveAuthMethodRefreshesExpiredEntireDBToken(t *testing.T) {
t.Fatalf("write refresh token: %v", err)
}

resolved, err := auth.Resolve(auth.Endpoint{SkipTLSVerify: true}, ep)
resolved, err := auth.Resolve(context.Background(), auth.Endpoint{SkipTLSVerify: true}, ep)
if err != nil {
t.Fatalf("resolve auth: %v", err)
}
Expand Down
4 changes: 2 additions & 2 deletions internal/syncer/git_http_backend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -515,7 +515,7 @@ func TestBootstrap_GitHTTPBackendBatchedBranchResume(t *testing.T) {
}

stats := newStats(false)
sourceConn, err := newConn(cfg.Source, "source", stats, nil)
sourceConn, err := newConn(context.Background(), cfg.Source, "source", stats, nil)
if err != nil {
t.Fatalf("create source transport: %v", err)
}
Expand Down Expand Up @@ -612,7 +612,7 @@ func TestBootstrap_GitHTTPBackendBatchedPlanningTracksBatchLimit(t *testing.T) {
}

stats := newStats(false)
sourceConn, err := newConn(cfg.Source, "source", stats, nil)
sourceConn, err := newConn(context.Background(), cfg.Source, "source", stats, nil)
if err != nil {
t.Fatalf("create source transport: %v", err)
}
Expand Down
8 changes: 4 additions & 4 deletions internal/syncer/syncer.go
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ func measurementLine(m Measurement) []string {

// --- Session setup ---

func newConn(raw Endpoint, label string, stats *statsCollector, httpClient *http.Client) (gitproto.Conn, error) {
func newConn(ctx context.Context, raw Endpoint, label string, stats *statsCollector, httpClient *http.Client) (gitproto.Conn, error) {
ep, err := transport.ParseURL(raw.URL)
if err != nil {
return nil, fmt.Errorf("parse endpoint: %w", err)
Expand All @@ -369,7 +369,7 @@ func newConn(raw Endpoint, label string, stats *statsCollector, httpClient *http
BearerToken: raw.BearerToken,
SkipTLSVerify: raw.SkipTLSVerify,
}
authMethod, err := auth.Resolve(authEp, ep)
authMethod, err := auth.Resolve(ctx, authEp, ep)
if err != nil {
return nil, fmt.Errorf("resolve auth: %w", err)
}
Expand Down Expand Up @@ -679,7 +679,7 @@ func newSession(ctx context.Context, cfg Config, needTarget bool) (*syncSession,
}))
}

s.sourceConn, err = newConn(cfg.Source, "source", s.stats, cfg.HTTPClient)
s.sourceConn, err = newConn(ctx, cfg.Source, "source", s.stats, cfg.HTTPClient)
if err != nil {
return nil, fmt.Errorf("create source transport: %w", err)
}
Expand All @@ -696,7 +696,7 @@ func newSession(ctx context.Context, cfg Config, needTarget bool) (*syncSession,
s.sourceRefMap = gitproto.RefHashMap(sourceRefs)

if needTarget {
targetConn, err := newConn(cfg.Target, "target", s.stats, cfg.HTTPClient)
targetConn, err := newConn(ctx, cfg.Target, "target", s.stats, cfg.HTTPClient)
if err != nil {
return nil, fmt.Errorf("create target transport: %w", err)
}
Expand Down
10 changes: 5 additions & 5 deletions internal/syncer/syncer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ func TestFinalizeCountsTalliesPushedAndDeleted(t *testing.T) {

func TestGitHubOwnerRepo(t *testing.T) {
stats := newStats(false)
conn, err := newConn(Endpoint{URL: "https://github.com/torvalds/linux.git"}, "source", stats, nil)
conn, err := newConn(context.Background(), Endpoint{URL: "https://github.com/torvalds/linux.git"}, "source", stats, nil)
if err != nil {
t.Fatalf("new conn: %v", err)
}
Expand All @@ -148,7 +148,7 @@ func TestGitHubOwnerRepo(t *testing.T) {

func TestGitHubOwnerRepoRejectsNonGitHubSource(t *testing.T) {
stats := newStats(false)
conn, err := newConn(Endpoint{URL: "https://gitlab.com/group/project.git"}, "source", stats, nil)
conn, err := newConn(context.Background(), Endpoint{URL: "https://gitlab.com/group/project.git"}, "source", stats, nil)
if err != nil {
t.Fatalf("new conn: %v", err)
}
Expand Down Expand Up @@ -225,7 +225,7 @@ func TestProbeWithoutTargetIgnoresEndpointEqualityCheck(t *testing.T) {
func TestNewHTTPConn_PropagatesFollowInfoRefsRedirect(t *testing.T) {
stats := newStats(false)

off, err := newConn(Endpoint{URL: "https://node.example/repo.git"}, "target", stats, nil)
off, err := newConn(context.Background(), Endpoint{URL: "https://node.example/repo.git"}, "target", stats, nil)
if err != nil {
t.Fatalf("new conn (off): %v", err)
}
Expand All @@ -237,7 +237,7 @@ func TestNewHTTPConn_PropagatesFollowInfoRefsRedirect(t *testing.T) {
t.Error("FollowInfoRefsRedirect should default to false")
}

on, err := newConn(Endpoint{URL: "https://node.example/repo.git", FollowInfoRefsRedirect: true}, "target", stats, nil)
on, err := newConn(context.Background(), Endpoint{URL: "https://node.example/repo.git", FollowInfoRefsRedirect: true}, "target", stats, nil)
if err != nil {
t.Fatalf("new conn (on): %v", err)
}
Expand Down Expand Up @@ -269,7 +269,7 @@ func TestNewConnBuildsSSHTransport(t *testing.T) {
}
for _, raw := range tests {
t.Run(raw, func(t *testing.T) {
conn, err := newConn(Endpoint{URL: raw}, "source", stats, nil)
conn, err := newConn(context.Background(), Endpoint{URL: raw}, "source", stats, nil)
if err != nil {
t.Fatalf("new conn: %v", err)
}
Expand Down
Loading