From 013775d696e074e3d8c030c34af0cecca2e86c8c Mon Sep 17 00:00:00 2001
From: Paulo Gomes <paulo@entire.io>
Date: Thu, 16 Apr 2026 12:14:28 +0100
Subject: [PATCH] plumbing: transport/http, Add support for followRedirects
 policy Back-port from #1997.

Signed-off-by: Paulo Gomes <paulo@entire.io>
---
 plumbing/transport/http/common.go | 168 ++++++++++++++++++++++++++----
 1 file changed, 147 insertions(+), 21 deletions(-)

diff --git a/plumbing/transport/http/common.go b/plumbing/transport/http/common.go
index 5dd2e311..83f93f16 100644
--- a/plumbing/transport/http/common.go
+++ b/plumbing/transport/http/common.go
@@ -7,7 +7,6 @@ import (
 	"crypto/tls"
 	"crypto/x509"
 	"fmt"
-	"net"
 	"net/http"
 	"net/url"
 	"reflect"
@@ -24,6 +23,33 @@ import (
 	"github.com/go-git/go-git/v5/utils/ioutil"
 )
 
+type contextKey int
+
+const initialRequestKey contextKey = iota
+
+// RedirectPolicy controls how the HTTP transport follows redirects.
+//
+// The values mirror Git's http.followRedirects config:
+// "true" follows redirects for all requests, "false" treats redirects as
+// errors, and "initial" follows redirects only for the initial
+// /info/refs discovery request. The zero value defaults to "initial".
+type RedirectPolicy string
+
+const (
+	FollowInitialRedirects RedirectPolicy = "initial"
+	FollowRedirects        RedirectPolicy = "true"
+	NoFollowRedirects      RedirectPolicy = "false"
+)
+
+func withInitialRequest(ctx context.Context) context.Context {
+	return context.WithValue(ctx, initialRequestKey, true)
+}
+
+func isInitialRequest(req *http.Request) bool {
+	v, _ := req.Context().Value(initialRequestKey).(bool)
+	return v
+}
+
 // it requires a bytes.Buffer, because we need to know the length
 func applyHeadersToRequest(req *http.Request, content *bytes.Buffer, host string, requestType string) {
 	req.Header.Add("User-Agent", capability.DefaultAgent())
@@ -54,12 +80,15 @@ func advertisedReferences(ctx context.Context, s *session, serviceName string) (
 
 	s.ApplyAuthToRequest(req)
 	applyHeadersToRequest(req, nil, s.endpoint.Host, serviceName)
-	res, err := s.client.Do(req.WithContext(ctx))
+	res, err := s.client.Do(req.WithContext(withInitialRequest(ctx)))
 	if err != nil {
 		return nil, err
 	}
 
-	s.ModifyEndpointIfRedirect(res)
+	if err := s.ModifyEndpointIfRedirect(res); err != nil {
+		_ = res.Body.Close()
+		return nil, err
+	}
 	defer ioutil.CheckClose(res.Body, &err)
 
 	if err = NewErr(res); err != nil {
@@ -96,6 +125,7 @@ type client struct {
 	client     *http.Client
 	transports *lru.Cache
 	mutex      sync.RWMutex
+	follow     RedirectPolicy
 }
 
 // ClientOptions holds user configurable options for the client.
@@ -106,6 +136,11 @@ type ClientOptions struct {
 	// size, will result in the least recently used transport getting deleted
 	// before the provided transport is added to the cache.
 	CacheMaxEntries int
+
+	// RedirectPolicy controls redirect handling. Supported values are
+	// "true", "false", and "initial". The zero value defaults to
+	// "initial", matching Git's http.followRedirects default.
+	RedirectPolicy RedirectPolicy
 }
 
 var (
@@ -150,12 +185,16 @@ func NewClientWithOptions(c *http.Client, opts *ClientOptions) transport.Transpo
 	}
 	cl := &client{
 		client: c,
+		follow: FollowInitialRedirects,
 	}
 
 	if opts != nil {
 		if opts.CacheMaxEntries > 0 {
 			cl.transports = lru.New(opts.CacheMaxEntries)
 		}
+		if opts.RedirectPolicy != "" {
+			cl.follow = opts.RedirectPolicy
+		}
 	}
 	return cl
 }
@@ -289,14 +328,9 @@ func newSession(c *client, ep *transport.Endpoint, auth transport.AuthMethod) (*
 			}
 		}
 
-		httpClient = &http.Client{
-			Transport:     transport,
-			CheckRedirect: c.client.CheckRedirect,
-			Jar:           c.client.Jar,
-			Timeout:       c.client.Timeout,
-		}
+		httpClient = c.cloneHTTPClient(transport)
 	} else {
-		httpClient = c.client
+		httpClient = c.cloneHTTPClient(c.client.Transport)
 	}
 
 	s := &session{
@@ -324,30 +358,122 @@ func (s *session) ApplyAuthToRequest(req *http.Request) {
 	s.auth.SetAuth(req)
 }
 
-func (s *session) ModifyEndpointIfRedirect(res *http.Response) {
+func (s *session) ModifyEndpointIfRedirect(res *http.Response) error {
 	if res.Request == nil {
-		return
+		return nil
+	}
+	if s.endpoint == nil {
+		return fmt.Errorf("http redirect: nil endpoint")
 	}
 
 	r := res.Request
 	if !strings.HasSuffix(r.URL.Path, infoRefsPath) {
-		return
+		return fmt.Errorf("http redirect: target %q does not end with %s", r.URL.Path, infoRefsPath)
+	}
+	if r.URL.Scheme != "http" && r.URL.Scheme != "https" {
+		return fmt.Errorf("http redirect: unsupported scheme %q", r.URL.Scheme)
+	}
+	if r.URL.Scheme != s.endpoint.Protocol &&
+		!(s.endpoint.Protocol == "http" && r.URL.Scheme == "https") {
+		return fmt.Errorf("http redirect: changes scheme from %q to %q", s.endpoint.Protocol, r.URL.Scheme)
 	}
 
-	h, p, err := net.SplitHostPort(r.URL.Host)
+	host := endpointHost(r.URL.Hostname())
+	port, err := endpointPort(r.URL.Port())
 	if err != nil {
-		h = r.URL.Host
+		return err
 	}
-	if p != "" {
-		port, err := strconv.Atoi(p)
-		if err == nil {
-			s.endpoint.Port = port
-		}
+
+	if host != s.endpoint.Host || effectivePort(r.URL.Scheme, port) != effectivePort(s.endpoint.Protocol, s.endpoint.Port) {
+		s.endpoint.User = ""
+		s.endpoint.Password = ""
+		s.auth = nil
 	}
-	s.endpoint.Host = h
+
+	s.endpoint.Host = host
+	s.endpoint.Port = port
 
 	s.endpoint.Protocol = r.URL.Scheme
 	s.endpoint.Path = r.URL.Path[:len(r.URL.Path)-len(infoRefsPath)]
+	return nil
+}
+
+func endpointHost(host string) string {
+	if strings.Contains(host, ":") {
+		return "[" + host + "]"
+	}
+
+	return host
+}
+
+func endpointPort(port string) (int, error) {
+	if port == "" {
+		return 0, nil
+	}
+
+	parsed, err := strconv.Atoi(port)
+	if err != nil {
+		return 0, fmt.Errorf("http redirect: invalid port %q", port)
+	}
+
+	return parsed, nil
+}
+
+func effectivePort(scheme string, port int) int {
+	if port != 0 {
+		return port
+	}
+
+	switch strings.ToLower(scheme) {
+	case "http":
+		return 80
+	case "https":
+		return 443
+	default:
+		return 0
+	}
+}
+
+func (c *client) cloneHTTPClient(transport http.RoundTripper) *http.Client {
+	return &http.Client{
+		Transport:     transport,
+		CheckRedirect: wrapCheckRedirect(c.follow, c.client.CheckRedirect),
+		Jar:           c.client.Jar,
+		Timeout:       c.client.Timeout,
+	}
+}
+
+func wrapCheckRedirect(policy RedirectPolicy, next func(*http.Request, []*http.Request) error) func(*http.Request, []*http.Request) error {
+	return func(req *http.Request, via []*http.Request) error {
+		if err := checkRedirect(req, via, policy); err != nil {
+			return err
+		}
+		if next != nil {
+			return next(req, via)
+		}
+		return nil
+	}
+}
+
+func checkRedirect(req *http.Request, via []*http.Request, policy RedirectPolicy) error {
+	switch policy {
+	case FollowRedirects:
+	case NoFollowRedirects:
+		return fmt.Errorf("http redirect: redirects disabled to %s", req.URL)
+	case "", FollowInitialRedirects:
+		if !isInitialRequest(req) {
+			return fmt.Errorf("http redirect: redirect on non-initial request to %s", req.URL)
+		}
+	default:
+		return fmt.Errorf("http redirect: invalid redirect policy %q", policy)
+	}
+	if req.URL.Scheme != "http" && req.URL.Scheme != "https" {
+		return fmt.Errorf("http redirect: unsupported scheme %q", req.URL.Scheme)
+	}
+	if len(via) >= 10 {
+		return fmt.Errorf("http redirect: too many redirects")
+	}
+	return nil
 }
 
 func (*session) Close() error {
-- 
2.54.0

