Skip to content
This repository was archived by the owner on Jul 26, 2024. It is now read-only.

Commit 86a1b3c

Browse files
committed
Add context support
1 parent 0784ece commit 86a1b3c

2 files changed

Lines changed: 93 additions & 2 deletions

File tree

message.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package gomail
22

33
import (
44
"bytes"
5+
"context"
56
"io"
67
"os"
78
"path/filepath"
@@ -19,6 +20,7 @@ type Message struct {
1920
hEncoder mimeEncoder
2021
buf bytes.Buffer
2122
boundary string
23+
ctx context.Context
2224
}
2325

2426
type header map[string][]string
@@ -36,6 +38,7 @@ func NewMessage(settings ...MessageSetting) *Message {
3638
header: make(header),
3739
charset: "UTF-8",
3840
encoding: QuotedPrintable,
41+
ctx: context.Background(),
3942
}
4043

4144
m.applySettings(settings)
@@ -60,6 +63,19 @@ func (m *Message) Reset() {
6063
m.embedded = nil
6164
}
6265

66+
// Context returns the message's internal context. The context is either set
67+
// using SetContext or it's defaulted to Background.
68+
func (m *Message) Context() context.Context {
69+
return m.ctx
70+
}
71+
72+
// WithContext copies the message and makes it use a different context.
73+
func (m *Message) WithContext(ctx context.Context) *Message {
74+
m2 := *m
75+
m2.ctx = ctx
76+
return &m2
77+
}
78+
6379
func (m *Message) applySettings(settings []MessageSetting) {
6480
for _, s := range settings {
6581
s(m)
@@ -84,6 +100,15 @@ func SetEncoding(enc Encoding) MessageSetting {
84100
}
85101
}
86102

103+
// SetContext is a message setting to set the context of the email. The context
104+
// determines cancellation and timeout for sending the message over the SMTP
105+
// connection.
106+
func SetContext(ctx context.Context) MessageSetting {
107+
return func(m *Message) {
108+
m.ctx = ctx
109+
}
110+
}
111+
87112
// Encoding represents a MIME encoding scheme like quoted-printable or base64.
88113
type Encoding string
89114

smtp.go

Lines changed: 68 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package gomail
22

33
import (
4+
"context"
45
"crypto/tls"
56
"fmt"
67
"io"
@@ -26,6 +27,9 @@ type Dialer struct {
2627
Auth smtp.Auth
2728
// Port represents the port of the SMTP server.
2829
Port int
30+
// NetDialer is the net.Dialer instance to use. For legacy purposes, if
31+
// NetDialTimeout is not net.DialTimeout, then this field is not used.
32+
NetDialer net.Dialer
2933
// TLSConfig represents the TLS configuration used for the TLS (when the
3034
// STARTTLS extension is used) or SSL connection.
3135
TLSConfig *tls.Config
@@ -74,12 +78,30 @@ func NewPlainDialer(host string, port int, username, password string) *Dialer {
7478
// NetDialTimeout specifies the DialTimeout function to establish a connection
7579
// to the SMTP server. This can be used to override dialing in the case that a
7680
// proxy or other special behavior is needed.
81+
//
82+
// Deprecated: use (*Dialer).NetDialer instead. If NetDialTimeout is nil, then
83+
// (*Dialer).NetDialer is used.
7784
var NetDialTimeout = net.DialTimeout
7885

7986
// Dial dials and authenticates to an SMTP server. The returned SendCloser
8087
// should be closed when done using it.
8188
func (d *Dialer) Dial() (SendCloser, error) {
82-
conn, err := NetDialTimeout("tcp", addr(d.Host, d.Port), d.Timeout)
89+
return d.DialCtx(context.Background())
90+
}
91+
92+
// DialCtx is Dial with context support.
93+
func (d *Dialer) DialCtx(ctx context.Context) (SendCloser, error) {
94+
var conn net.Conn
95+
var err error
96+
97+
if NetDialTimeout == nil {
98+
ctx, cancel := context.WithTimeout(ctx, d.Timeout)
99+
defer cancel()
100+
101+
conn, err = d.NetDialer.DialContext(ctx, "tcp", addr(d.Host, d.Port))
102+
} else {
103+
conn, err = NetDialTimeout("tcp", addr(d.Host, d.Port), d.Timeout)
104+
}
83105
if err != nil {
84106
return nil, err
85107
}
@@ -93,6 +115,19 @@ func (d *Dialer) Dial() (SendCloser, error) {
93115
return nil, err
94116
}
95117

118+
doneCtx, cancel := context.WithCancel(ctx)
119+
defer cancel()
120+
121+
go func() {
122+
select {
123+
case <-ctx.Done():
124+
// Parent context expired. Immediately terminate and return.
125+
c.Close()
126+
case <-doneCtx.Done():
127+
// ok
128+
}
129+
}()
130+
96131
if d.Timeout > 0 {
97132
conn.SetDeadline(time.Now().Add(d.Timeout))
98133
}
@@ -201,12 +236,21 @@ func addr(host string, port int) string {
201236
// DialAndSend opens a connection to the SMTP server, sends the given emails and
202237
// closes the connection.
203238
func (d *Dialer) DialAndSend(m ...*Message) error {
204-
s, err := d.Dial()
239+
return d.DialAndSendCtx(context.Background(), m...)
240+
}
241+
242+
// DialAndSendCtx is DialAndSend with context support.
243+
func (d *Dialer) DialAndSendCtx(ctx context.Context, m ...*Message) error {
244+
s, err := d.DialCtx(ctx)
205245
if err != nil {
206246
return err
207247
}
208248
defer s.Close()
209249

250+
for i := range m {
251+
m[i] = m[i].WithContext(ctx)
252+
}
253+
210254
return Send(s, m...)
211255
}
212256

@@ -228,7 +272,29 @@ func (c *smtpSender) retryError(err error) bool {
228272
return err == io.EOF
229273
}
230274

275+
type messageContexter interface {
276+
Context() context.Context
277+
}
278+
279+
var _ messageContexter = (*Message)(nil)
280+
231281
func (c *smtpSender) Send(from string, to []string, msg io.WriterTo) error {
282+
if ctxer, ok := msg.(messageContexter); ok {
283+
if ctx := ctxer.Context(); ctx != context.Background() {
284+
doneCtx, cancel := context.WithCancel(ctx)
285+
defer cancel()
286+
287+
go func() {
288+
select {
289+
case <-ctx.Done():
290+
c.conn.Close()
291+
case <-doneCtx.Done():
292+
// ok
293+
}
294+
}()
295+
}
296+
}
297+
232298
if c.d.Timeout > 0 {
233299
c.conn.SetDeadline(time.Now().Add(c.d.Timeout))
234300
}

0 commit comments

Comments
 (0)