diff --git a/cli/internal/tygerproxy/tygerproxy.go b/cli/internal/tygerproxy/tygerproxy.go index c0541bd1..2239c66f 100644 --- a/cli/internal/tygerproxy/tygerproxy.go +++ b/cli/internal/tygerproxy/tygerproxy.go @@ -189,13 +189,29 @@ func CheckProxyAlreadyRunning(options *ProxyOptions) (*ProxyServiceMetadata, err return existingProxy, ErrProxyNotRunning } - if existingProxy.ServerUrl != options.ServerUrl { + if !urlsEquivalent(existingProxy.ServerUrl, options.ServerUrl) { return existingProxy, ErrProxyAlreadyRunningWrongTarget } return existingProxy, nil } +func urlsEquivalent(a, b string) bool { + return normalizeUrl(a) == normalizeUrl(b) +} + +func normalizeUrl(raw string) string { + u, err := url.Parse(raw) + if err != nil { + return raw + } + u.Scheme = strings.ToLower(u.Scheme) + u.Host = strings.ToLower(u.Host) + u.Path = strings.TrimRight(u.Path, "/") + u.RawQuery = u.Query().Encode() + return u.String() +} + func GetExistingProxyMetadata(options *ProxyOptions) *ProxyServiceMetadata { // note: not using retryablehttp here because we are hitting localhost // and we want to fail quickly diff --git a/cli/internal/tygerproxy/tygerproxy_test.go b/cli/internal/tygerproxy/tygerproxy_test.go new file mode 100644 index 00000000..7e535507 --- /dev/null +++ b/cli/internal/tygerproxy/tygerproxy_test.go @@ -0,0 +1,46 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +package tygerproxy + +import "testing" + +func TestUrlsEquivalent(t *testing.T) { + tests := []struct { + name string + a string + b string + want bool + }{ + {"identical", "http://myserver:8080", "http://myserver:8080", true}, + {"scheme casing", "http://myserver:8080", "HTTP://myserver:8080", true}, + {"host casing", "http://MyServer:8080", "http://myserver:8080", true}, + {"mixed case scheme and host", "HTTPS://MyServer.Example.COM/api", "https://myserver.example.com/api", true}, + {"escaped vs unescaped brackets in query", + "ssh://user@myhost/opt/tyger/api.sock?option[StrictHostKeyChecking]=no&option[UserKnownHostsFile]=NUL", + "ssh://user@myhost/opt/tyger/api.sock?option%5BStrictHostKeyChecking%5D=no&option%5BUserKnownHostsFile%5D=NUL", + true}, + {"escaped vs unescaped brackets reversed", + "ssh://user@myhost/opt/tyger/api.sock?option%5BStrictHostKeyChecking%5D=no&option%5BUserKnownHostsFile%5D=NUL", + "ssh://user@myhost/opt/tyger/api.sock?option[StrictHostKeyChecking]=no&option[UserKnownHostsFile]=NUL", + true}, + {"different servers", "http://server-a:8080", "http://server-b:8080", false}, + {"different ports", "http://myserver:8080", "http://myserver:9090", false}, + {"different schemes", "http://myserver:8080", "https://myserver:8080", false}, + {"different query values", + "ssh://user@myhost/opt/tyger/api.sock?option[StrictHostKeyChecking]=no", + "ssh://user@myhost/opt/tyger/api.sock?option[StrictHostKeyChecking]=yes", + false}, + {"different paths", "http://myserver:8080/a", "http://myserver:8080/b", false}, + {"trailing slash", "http://myserver:8080/api/", "http://myserver:8080/api", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := urlsEquivalent(tt.a, tt.b) + if got != tt.want { + t.Errorf("UrlsEquivalent(%q, %q) = %v, want %v", tt.a, tt.b, got, tt.want) + } + }) + } +}