diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e43b0f9 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +.DS_Store diff --git a/urlbuilder/urlbuilder.go b/urlbuilder/urlbuilder.go index 63804c2..f9cf723 100644 --- a/urlbuilder/urlbuilder.go +++ b/urlbuilder/urlbuilder.go @@ -11,23 +11,57 @@ import ( type URLBuilder struct { scheme string host string - path []string + path []val query url.Values fragment string } +type val struct { + val string + shouldEscape bool +} + // New creates a new URLBuilder with the given scheme and host -func New(scheme string, host string) *URLBuilder { +func New() *URLBuilder { return &URLBuilder{ - scheme: scheme, - host: host, - query: url.Values{}, + query: make(url.Values), } } +// Scheme creates a new URLBuilder with the given scheme +func Scheme(scheme string) *URLBuilder { + return New().Scheme(scheme) +} + +// Host creates a new URLBuilder with the given host +// The scheme will be protocol-relative +func Host(host string) *URLBuilder { + return New().Host(host) +} + +// Path creates a new URLBuilder with the given path segment +// This path will *not* be escaped +func Path(segment string) *URLBuilder { + ub := New() + ub.path = []val{{val: segment, shouldEscape: false}} + return ub +} + +// Scheme sets the scheme of the URL +func (ub *URLBuilder) Scheme(scheme string) *URLBuilder { + ub.scheme = scheme + return ub +} + +// Host sets the host of the URL +func (ub *URLBuilder) Host(host string) *URLBuilder { + ub.host = host + return ub +} + // Path adds a path segment to the URL func (ub *URLBuilder) Path(segment string) *URLBuilder { - ub.path = append(ub.path, segment) + ub.path = append(ub.path, val{val: segment, shouldEscape: true}) return ub } @@ -46,13 +80,34 @@ func (ub *URLBuilder) Fragment(fragment string) *URLBuilder { // Build constructs the final URL as a SafeURL func (ub *URLBuilder) Build() templ.SafeURL { var buf strings.Builder - buf.WriteString(ub.scheme) - buf.WriteString("://") - buf.WriteString(ub.host) + switch ub.scheme { + case "tel", "mailto": + buf.WriteString(ub.scheme) + buf.WriteByte(':') + buf.WriteString(ub.host) + return templ.SafeURL(buf.String()) + default: + if ub.scheme != "" { + buf.WriteString(ub.scheme) + buf.WriteByte(':') + } + } + + if ub.host != "" { + buf.WriteString("//") + buf.WriteString(ub.host) + } for _, segment := range ub.path { - buf.WriteByte('/') - buf.WriteString(url.PathEscape(segment)) + + if !strings.HasPrefix(segment.val, "/") { + buf.WriteByte('/') + } + if segment.shouldEscape { + buf.WriteString(url.PathEscape(segment.val)) + } else { + buf.WriteString(segment.val) + } } if len(ub.query) > 0 { @@ -65,5 +120,5 @@ func (ub *URLBuilder) Build() templ.SafeURL { buf.WriteString(url.QueryEscape(ub.fragment)) } - return templ.URL(buf.String()) + return templ.SafeURL((buf.String())) } diff --git a/urlbuilder/urlbuilder_test.go b/urlbuilder/urlbuilder_test.go index c02d237..79cd8f5 100644 --- a/urlbuilder/urlbuilder_test.go +++ b/urlbuilder/urlbuilder_test.go @@ -10,7 +10,9 @@ func BenchmarkURLBuilder(b *testing.B) { b.ReportAllocs() for i := 0; i < b.N; i++ { - New("https", "example.com"). + New(). + Scheme("https"). + Host("example.com"). Path("a"). Path("b"). Path("c"). @@ -28,7 +30,8 @@ func BenchmarkURLBuilder(b *testing.B) { func TestBasicURL(t *testing.T) { t.Parallel() - got := New("https", "example.com"). + got := Scheme("https"). + Host("example.com"). Build() expected := templ.URL("https://example.com") @@ -42,7 +45,8 @@ func TestURLWithPaths(t *testing.T) { t.Parallel() c := "c" - got := New("https", "example.com"). + got := Scheme("https"). + Host("example.com"). Path("a"). Path("b"). Path(c). @@ -58,7 +62,8 @@ func TestURLWithPaths(t *testing.T) { func TestURLWithMultipleQueries(t *testing.T) { t.Parallel() - got := New("https", "example.com"). + got := Scheme("https"). + Host("example.com"). Path("path"). Query("key1", "value1"). Query("key2", "value2"). @@ -74,11 +79,12 @@ func TestURLWithMultipleQueries(t *testing.T) { func TestURLWithNoPaths(t *testing.T) { t.Parallel() - got := New("http", "example.org"). + got := Scheme("https"). + Host("example.com"). Query("search", "golang"). Build() - expected := templ.URL("http://example.org?search=golang") + expected := templ.URL("https://example.com?search=golang") if got != expected { t.Fatalf("got %s, want %s", got, expected) @@ -88,7 +94,8 @@ func TestURLWithNoPaths(t *testing.T) { func TestURLEscapingPath(t *testing.T) { t.Parallel() - got := New("https", "example.com"). + got := Scheme("https"). + Host("example.com"). Path("a/b"). Path("c d"). Build() @@ -103,7 +110,8 @@ func TestURLEscapingPath(t *testing.T) { func TestURLEscapingQuery(t *testing.T) { t.Parallel() - got := New("https", "example.com"). + got := Scheme("https"). + Host("example.com"). Query("key with space", "value with space"). Query("key/with/slash", "value/with/slash"). Build() @@ -114,3 +122,126 @@ func TestURLEscapingQuery(t *testing.T) { t.Fatalf("got %s, want %s", got, expected) } } + +func TestPath(t *testing.T) { + t.Parallel() + + got := Path("chat"). + Path("response"). + Query("input", "hello!"). + Build() + + expected := templ.URL("/chat/response?input=hello%21") + + if got != expected { + t.Fatalf("got %s, want %s", got, expected) + } +} + +func TestProtocolRelative(t *testing.T) { + t.Parallel() + + got := Host("example.com").Build() + + expected := templ.URL("//example.com") + + if got != expected { + t.Fatalf("got %s, want %s", got, expected) + } +} + +func TestSlash(t *testing.T) { + t.Parallel() + + got := Path("/").Build() + + expected := templ.URL("/") + + if got != expected { + t.Fatalf("got %s, want %s", got, expected) + } +} + +func TestSlashIndex(t *testing.T) { + t.Parallel() + + got := Path("/index").Build() + + expected := templ.URL("/index") + + if got != expected { + t.Fatalf("got %s, want %s", got, expected) + } +} + +func TestHTTP(t *testing.T) { + t.Parallel() + + got := Scheme("http").Host("example.com").Build() + + expected := templ.URL("http://example.com") + + if got != expected { + t.Fatalf("got %s, want %s", got, expected) + } +} + +func TestHTTPS(t *testing.T) { + t.Parallel() + + got := Scheme("https").Host("example.com").Build() + + expected := templ.URL("https://example.com") + + if got != expected { + t.Fatalf("got %s, want %s", got, expected) + } +} + +func TestMailTo(t *testing.T) { + t.Parallel() + + got := Scheme("mailto").Host("test@example.com").Build() + + expected := templ.URL("mailto:test@example.com") + + if got != expected { + t.Fatalf("got %s, want %s", got, expected) + } +} + +func TestTel(t *testing.T) { + t.Parallel() + + got := Scheme("tel").Host("+1234567890").Build() + + expected := templ.URL("tel:+1234567890") + + if got != expected { + t.Fatalf("got %s, want %s", got, expected) + } +} + +func TestFtp(t *testing.T) { + t.Parallel() + + got := Scheme("ftp").Host("example.com").Build() + + expected := templ.URL("ftp://example.com") + + if got != expected { + t.Fatalf("got %s, want %s", got, expected) + } +} + +func TestFtps(t *testing.T) { + t.Parallel() + + got := Scheme("ftps").Host("example.com").Build() + + expected := templ.URL("ftps://example.com") + + if got != expected { + t.Fatalf("got %s, want %s", got, expected) + } +}