Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
.DS_Store
79 changes: 67 additions & 12 deletions urlbuilder/urlbuilder.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,57 @@ import (
type URLBuilder struct {
scheme string
host string
path []string
path []val
query url.Values
fragment string
}

type val struct {
val string
shouldEscape bool
}

// New creates a new URLBuilder with the given scheme and host
func New(scheme string, host string) *URLBuilder {
func New() *URLBuilder {
return &URLBuilder{
scheme: scheme,
host: host,
query: url.Values{},
query: make(url.Values),
}
}

// Scheme creates a new URLBuilder with the given scheme
func Scheme(scheme string) *URLBuilder {
return New().Scheme(scheme)
}

// Host creates a new URLBuilder with the given host
// The scheme will be protocol-relative
func Host(host string) *URLBuilder {
return New().Host(host)
}

// Path creates a new URLBuilder with the given path segment
// This path will *not* be escaped
func Path(segment string) *URLBuilder {
ub := New()
ub.path = []val{{val: segment, shouldEscape: false}}
return ub
}

// Scheme sets the scheme of the URL
func (ub *URLBuilder) Scheme(scheme string) *URLBuilder {
ub.scheme = scheme
return ub
}

// Host sets the host of the URL
func (ub *URLBuilder) Host(host string) *URLBuilder {
ub.host = host
return ub
}

// Path adds a path segment to the URL
func (ub *URLBuilder) Path(segment string) *URLBuilder {
ub.path = append(ub.path, segment)
ub.path = append(ub.path, val{val: segment, shouldEscape: true})
return ub
}

Expand All @@ -46,13 +80,34 @@ func (ub *URLBuilder) Fragment(fragment string) *URLBuilder {
// Build constructs the final URL as a SafeURL
func (ub *URLBuilder) Build() templ.SafeURL {
var buf strings.Builder
buf.WriteString(ub.scheme)
buf.WriteString("://")
buf.WriteString(ub.host)
switch ub.scheme {
case "tel", "mailto":
buf.WriteString(ub.scheme)
buf.WriteByte(':')
buf.WriteString(ub.host)
return templ.SafeURL(buf.String())
default:
if ub.scheme != "" {
buf.WriteString(ub.scheme)
buf.WriteByte(':')
}
}

if ub.host != "" {
buf.WriteString("//")
buf.WriteString(ub.host)
}

for _, segment := range ub.path {
buf.WriteByte('/')
buf.WriteString(url.PathEscape(segment))

if !strings.HasPrefix(segment.val, "/") {
buf.WriteByte('/')
}
if segment.shouldEscape {
buf.WriteString(url.PathEscape(segment.val))
} else {
buf.WriteString(segment.val)
}
}

if len(ub.query) > 0 {
Expand All @@ -65,5 +120,5 @@ func (ub *URLBuilder) Build() templ.SafeURL {
buf.WriteString(url.QueryEscape(ub.fragment))
}

return templ.URL(buf.String())
return templ.SafeURL((buf.String()))
}
147 changes: 139 additions & 8 deletions urlbuilder/urlbuilder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ func BenchmarkURLBuilder(b *testing.B) {
b.ReportAllocs()

for i := 0; i < b.N; i++ {
New("https", "example.com").
New().
Scheme("https").
Host("example.com").
Path("a").
Path("b").
Path("c").
Expand All @@ -28,7 +30,8 @@ func BenchmarkURLBuilder(b *testing.B) {
func TestBasicURL(t *testing.T) {
t.Parallel()

got := New("https", "example.com").
got := Scheme("https").
Host("example.com").
Build()

expected := templ.URL("https://example.com")
Expand All @@ -42,7 +45,8 @@ func TestURLWithPaths(t *testing.T) {
t.Parallel()

c := "c"
got := New("https", "example.com").
got := Scheme("https").
Host("example.com").
Path("a").
Path("b").
Path(c).
Expand All @@ -58,7 +62,8 @@ func TestURLWithPaths(t *testing.T) {
func TestURLWithMultipleQueries(t *testing.T) {
t.Parallel()

got := New("https", "example.com").
got := Scheme("https").
Host("example.com").
Path("path").
Query("key1", "value1").
Query("key2", "value2").
Expand All @@ -74,11 +79,12 @@ func TestURLWithMultipleQueries(t *testing.T) {
func TestURLWithNoPaths(t *testing.T) {
t.Parallel()

got := New("http", "example.org").
got := Scheme("https").
Host("example.com").
Query("search", "golang").
Build()

expected := templ.URL("http://example.org?search=golang")
expected := templ.URL("https://example.com?search=golang")

if got != expected {
t.Fatalf("got %s, want %s", got, expected)
Expand All @@ -88,7 +94,8 @@ func TestURLWithNoPaths(t *testing.T) {
func TestURLEscapingPath(t *testing.T) {
t.Parallel()

got := New("https", "example.com").
got := Scheme("https").
Host("example.com").
Path("a/b").
Path("c d").
Build()
Expand All @@ -103,7 +110,8 @@ func TestURLEscapingPath(t *testing.T) {
func TestURLEscapingQuery(t *testing.T) {
t.Parallel()

got := New("https", "example.com").
got := Scheme("https").
Host("example.com").
Query("key with space", "value with space").
Query("key/with/slash", "value/with/slash").
Build()
Expand All @@ -114,3 +122,126 @@ func TestURLEscapingQuery(t *testing.T) {
t.Fatalf("got %s, want %s", got, expected)
}
}

func TestPath(t *testing.T) {
t.Parallel()

got := Path("chat").
Path("response").
Query("input", "hello!").
Build()

expected := templ.URL("/chat/response?input=hello%21")

if got != expected {
t.Fatalf("got %s, want %s", got, expected)
}
}

func TestProtocolRelative(t *testing.T) {
t.Parallel()

got := Host("example.com").Build()

expected := templ.URL("//example.com")

if got != expected {
t.Fatalf("got %s, want %s", got, expected)
}
}

func TestSlash(t *testing.T) {
t.Parallel()

got := Path("/").Build()

expected := templ.URL("/")

if got != expected {
t.Fatalf("got %s, want %s", got, expected)
}
}

func TestSlashIndex(t *testing.T) {
t.Parallel()

got := Path("/index").Build()

expected := templ.URL("/index")

if got != expected {
t.Fatalf("got %s, want %s", got, expected)
}
}

func TestHTTP(t *testing.T) {
t.Parallel()

got := Scheme("http").Host("example.com").Build()

expected := templ.URL("http://example.com")

if got != expected {
t.Fatalf("got %s, want %s", got, expected)
}
}

func TestHTTPS(t *testing.T) {
t.Parallel()

got := Scheme("https").Host("example.com").Build()

expected := templ.URL("https://example.com")

if got != expected {
t.Fatalf("got %s, want %s", got, expected)
}
}

func TestMailTo(t *testing.T) {
t.Parallel()

got := Scheme("mailto").Host("test@example.com").Build()

expected := templ.URL("mailto:test@example.com")

if got != expected {
t.Fatalf("got %s, want %s", got, expected)
}
}

func TestTel(t *testing.T) {
t.Parallel()

got := Scheme("tel").Host("+1234567890").Build()

expected := templ.URL("tel:+1234567890")

if got != expected {
t.Fatalf("got %s, want %s", got, expected)
}
}

func TestFtp(t *testing.T) {
t.Parallel()

got := Scheme("ftp").Host("example.com").Build()

expected := templ.URL("ftp://example.com")

if got != expected {
t.Fatalf("got %s, want %s", got, expected)
}
}

func TestFtps(t *testing.T) {
t.Parallel()

got := Scheme("ftps").Host("example.com").Build()

expected := templ.URL("ftps://example.com")

if got != expected {
t.Fatalf("got %s, want %s", got, expected)
}
}