diff --git a/cmd/ldapfetch/main.go b/cmd/ldapfetch/main.go new file mode 100644 index 0000000..b1921ce --- /dev/null +++ b/cmd/ldapfetch/main.go @@ -0,0 +1,53 @@ +// ldapfetch is a utility to fetch and display users from an LDAP server +// based on a given LDAP json configuration file matching ConfigLDAP structure +// +// Usage: ldapfetch +package main + +import ( + "encoding/json" + "fmt" + "github.com/IMQS/authaus" + "github.com/IMQS/log" + "os" +) + +func main() { + if len(os.Args) < 2 { + fmt.Println("Usage: ldapfetch ") + os.Exit(1) + } + s, e := os.ReadFile(os.Args[1]) + if e != nil { + fmt.Println("Error reading config file:", e) + os.Exit(1) + } + var ldapConf *authaus.ConfigLDAP + e = json.Unmarshal(s, &ldapConf) + + if e != nil { + fmt.Println("Error parsing config file:", e) + os.Exit(1) + } + if !ldapConf.DebugUserPull { + fmt.Println("Warning: DebugUserPull is not enabled in the config - " + + "you may not get extra user info from LDAP") + } + + ldapImpl := authaus.NewAuthenticator_LDAP(ldapConf) + logger := log.New(log.Stdout, true) + users, e := ldapImpl.GetLdapUsers(logger) + if e != nil { + fmt.Println("Error getting ldap users:", e) + os.Exit(1) + } + fmt.Printf("%d Auth users mapped\n", len(users)) + fmt.Printf("%25v | %25v | %40v\n", + "Username", "Firstname", "Lastname") + fmt.Printf("%s\n", "------------------------------------------------------------------------------------------"+ + "------") + for _, user := range users { + fmt.Printf("%25v | %25v | %40v\n", + user.Username, user.Firstname, user.Lastname) + } +} diff --git a/db.go b/db.go index 8145a82..46212cd 100644 --- a/db.go +++ b/db.go @@ -2,6 +2,7 @@ package authaus import ( "database/sql" + "github.com/IMQS/log" "sort" "strings" "sync" @@ -134,9 +135,9 @@ type UserStore interface { // The LDAP interface allows authentication and the ability to retrieve the LDAP's users and merge them into our system type LDAP interface { - Authenticate(identity, password string) error // Return nil if the password is correct, otherwise one of ErrIdentityAuthNotFound or ErrInvalidPassword - GetLdapUsers() ([]AuthUser, error) // Retrieve the list of users from ldap - Close() // Typically used to close a database handle + Authenticate(identity, password string) error // Return nil if the password is correct, otherwise one of ErrIdentityAuthNotFound or ErrInvalidPassword + GetLdapUsers(log *log.Logger) ([]AuthUser, error) // Retrieve the list of users from ldap + Close() // Typically used to close a database handle } // A Permit database performs no validation. It simply returns the Permit owned by a particular user. @@ -338,8 +339,8 @@ func (x *sanitizingLDAP) Authenticate(identity, password string) error { return x.backend.Authenticate(identity, password) } -func (x *sanitizingLDAP) GetLdapUsers() ([]AuthUser, error) { - return x.backend.GetLdapUsers() +func (x *sanitizingLDAP) GetLdapUsers(log *log.Logger) ([]AuthUser, error) { + return x.backend.GetLdapUsers(log) } func (x *sanitizingLDAP) Close() { diff --git a/dummyLDAP.go b/dummyLDAP.go index d37346f..de4c6dc 100644 --- a/dummyLDAP.go +++ b/dummyLDAP.go @@ -1,6 +1,7 @@ package authaus import ( + "github.com/IMQS/log" "sync" ) @@ -35,7 +36,7 @@ func (x *dummyLdap) Authenticate(identity, password string) (er error) { return } -func (x *dummyLdap) GetLdapUsers() ([]AuthUser, error) { +func (x *dummyLdap) GetLdapUsers(log *log.Logger) ([]AuthUser, error) { x.usersLock.RLock() defer x.usersLock.RUnlock() //Now we build up and return the list of ldap users ([]AuthUsers) diff --git a/go.mod b/go.mod index 3afe0f0..be6d5e8 100644 --- a/go.mod +++ b/go.mod @@ -1,21 +1,24 @@ module github.com/IMQS/authaus -go 1.22.7 +go 1.24.0 + +toolchain go1.24.10 require ( github.com/BurntSushi/migration v0.0.0-20140125045755-c45b897f1335 - github.com/IMQS/log v1.3.0 + github.com/IMQS/log v1.5.1 + github.com/go-ldap/ldap/v3 v3.4.12 github.com/google/uuid v1.6.0 github.com/lib/pq v1.10.9 - github.com/mavricknz/ldap v0.0.0-20160227184754-f5a958005e43 - github.com/stretchr/testify v1.9.0 - github.com/wI2L/jsondiff v0.6.1 - golang.org/x/crypto v0.31.0 + github.com/stretchr/testify v1.11.1 + github.com/wI2L/jsondiff v0.7.0 + golang.org/x/crypto v0.47.0 ) require ( + github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358 // indirect github.com/davecgh/go-spew v1.1.1 // indirect - github.com/mavricknz/asn1-ber v0.0.0-20151103223136-b9df1c2f4213 // indirect + github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/tidwall/gjson v1.18.0 // indirect github.com/tidwall/match v1.1.1 // indirect diff --git a/go.sum b/go.sum index ed65c02..3db7ca8 100644 --- a/go.sum +++ b/go.sum @@ -1,21 +1,39 @@ +github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358 h1:mFRzDkZVAjdal+s7s0MwaRv9igoPqLRdzOLzw/8Xvq8= +github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358/go.mod h1:chxPXzSsl7ZWRAuOIE23GDNzjWuZquvFlgA8xmpunjU= github.com/BurntSushi/migration v0.0.0-20140125045755-c45b897f1335 h1:n8o916boOorBHMGywZ+ucvUZRLIvjt2CaY/694CgMfU= github.com/BurntSushi/migration v0.0.0-20140125045755-c45b897f1335/go.mod h1:eVEKGm5N/F2XPdHocE3gP//Ab+rb/54WJ7XXtFGxwaQ= -github.com/IMQS/log v1.3.0 h1:3qSqHllvYd6KT7FjkzzuQ6eZfVdG+siphYTvYT6X6uA= -github.com/IMQS/log v1.3.0/go.mod h1:EVm4FzOIBh22Ucdy4n01j725B85Z7We3LaRKCVozvy8= +github.com/IMQS/log v1.5.1 h1:MrM5Cn4zUiH/cZqOd4A64sHrF+GldjN8UXOhiRKFRMc= +github.com/IMQS/log v1.5.1/go.mod h1:EVm4FzOIBh22Ucdy4n01j725B85Z7We3LaRKCVozvy8= +github.com/alexbrainman/sspi v0.0.0-20250919150558-7d374ff0d59e h1:4dAU9FXIyQktpoUAgOJK3OTFc/xug0PCXYCqU0FgDKI= +github.com/alexbrainman/sspi v0.0.0-20250919150558-7d374ff0d59e/go.mod h1:cEWa1LVoE5KvSD9ONXsZrj0z6KqySlCCNKHlLzbqAt4= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 h1:BP4M0CvQ4S3TGls2FvczZtj5Re/2ZzkV9VwqPHH/3Bo= +github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667/go.mod h1:hEBeB/ic+5LoWskz+yKT7vGhhPYkProFKoKdwZRWMe0= +github.com/go-ldap/ldap/v3 v3.4.12 h1:1b81mv7MagXZ7+1r7cLTWmyuTqVqdwbtJSjC0DAp9s4= +github.com/go-ldap/ldap/v3 v3.4.12/go.mod h1:+SPAGcTtOfmGsCb3h1RFiq4xpp4N636G75OEace8lNo= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/hashicorp/go-uuid v1.0.3 h1:2gKiV6YVmrJ1i2CKKa9obLvRieoRGviZFL26PcT/Co8= +github.com/hashicorp/go-uuid v1.0.3/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= +github.com/jcmturner/aescts/v2 v2.0.0 h1:9YKLH6ey7H4eDBXW8khjYslgyqG2xZikXP0EQFKrle8= +github.com/jcmturner/aescts/v2 v2.0.0/go.mod h1:AiaICIRyfYg35RUkr8yESTqvSy7csK90qZ5xfvvsoNs= +github.com/jcmturner/dnsutils/v2 v2.0.0 h1:lltnkeZGL0wILNvrNiVCR6Ro5PGU/SeBvVO/8c/iPbo= +github.com/jcmturner/dnsutils/v2 v2.0.0/go.mod h1:b0TnjGOvI/n42bZa+hmXL+kFJZsFT7G4t3HTlQ184QM= +github.com/jcmturner/gofork v1.7.6 h1:QH0l3hzAU1tfT3rZCnW5zXl+orbkNMMRGJfdJjHVETg= +github.com/jcmturner/gofork v1.7.6/go.mod h1:1622LH6i/EZqLloHfE7IeZ0uEJwMSUyQ/nDd82IeqRo= +github.com/jcmturner/goidentity/v6 v6.0.1 h1:VKnZd2oEIMorCTsFBnJWbExfNN7yZr3EhJAxwOkZg6o= +github.com/jcmturner/goidentity/v6 v6.0.1/go.mod h1:X1YW3bgtvwAXju7V3LCIMpY0Gbxyjn/mY9zx4tFonSg= +github.com/jcmturner/gokrb5/v8 v8.4.4 h1:x1Sv4HaTpepFkXbt2IkL29DXRf8sOfZXo8eRKh687T8= +github.com/jcmturner/gokrb5/v8 v8.4.4/go.mod h1:1btQEpgT6k+unzCwX1KdWMEwPPkkgBtP+F6aCACiMrs= +github.com/jcmturner/rpc/v2 v2.0.3 h1:7FXXj8Ti1IaVFpSAziCZWNzbNuZmnvw/i6CqLNdWfZY= +github.com/jcmturner/rpc/v2 v2.0.3/go.mod h1:VUJYCIDm3PVOEHw8sgt091/20OJjskO/YJki3ELg/Hc= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= -github.com/mavricknz/asn1-ber v0.0.0-20151103223136-b9df1c2f4213 h1:3DongGRjJZvIFDq063tg76LKlGhA7O0TVqoPql0Zfbk= -github.com/mavricknz/asn1-ber v0.0.0-20151103223136-b9df1c2f4213/go.mod h1:v/ZufymxjcI3pnNmQIUQQKxnHLTblrjZ4MNLs5DrZ1o= -github.com/mavricknz/ldap v0.0.0-20160227184754-f5a958005e43 h1:x4SDcUPDTMzuFEdWe5lTznj1echpsd0ApTkZOdwtm7g= -github.com/mavricknz/ldap v0.0.0-20160227184754-f5a958005e43/go.mod h1:z76yvVwVulPd8FyifHe8UEHeud6XXaSan0ibi2sDy6w= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= -github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= @@ -26,10 +44,12 @@ github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= -github.com/wI2L/jsondiff v0.6.1 h1:ISZb9oNWbP64LHnu4AUhsMF5W0FIj5Ok3Krip9Shqpw= -github.com/wI2L/jsondiff v0.6.1/go.mod h1:KAEIojdQq66oJiHhDyQez2x+sRit0vIzC9KeK0yizxM= -golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U= -golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= +github.com/wI2L/jsondiff v0.7.0 h1:1lH1G37GhBPqCfp/lrs91rf/2j3DktX6qYAKZkLuCQQ= +github.com/wI2L/jsondiff v0.7.0/go.mod h1:KAEIojdQq66oJiHhDyQez2x+sRit0vIzC9KeK0yizxM= +golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8= +golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A= +golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU= +golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= diff --git a/ldap.go b/ldap.go index 9fc67cd..803b81b 100644 --- a/ldap.go +++ b/ldap.go @@ -4,10 +4,11 @@ import ( "crypto/tls" "errors" "fmt" + "github.com/IMQS/log" + "github.com/go-ldap/ldap/v3" + "net" "strings" "time" - - "github.com/mavricknz/ldap" ) type LdapConnectionMode int @@ -22,6 +23,20 @@ type LdapImpl struct { config *ConfigLDAP } +func (x *LdapImpl) GetConfig() *ConfigLDAP { + return x.config +} + +type ldapEntry struct { + UserName string + GivenName string + Name string + Surname string + Email string + Mobile string + UserPrincipalName string +} + func (x *LdapImpl) Authenticate(identity, password string) error { if len(password) == 0 { // Many LDAP servers (or AD) will allow an anonymous BIND. @@ -33,8 +48,11 @@ func (x *LdapImpl) Authenticate(identity, password string) error { if err != nil { return err } - defer con.Close() - // We need to know whether or not we must add the domain to the identity by checking if it contains '@' + defer func() { + _ = con.Close() + }() + // We need to know whether we must add the domain to the identity by checking + // if it contains '@' if !strings.Contains(identity, "@") { identity = fmt.Sprintf(`%v@%v`, identity, x.config.LdapDomain) } @@ -53,8 +71,8 @@ func (x *LdapImpl) Close() { } -func (x *LdapImpl) GetLdapUsers() ([]AuthUser, error) { - var attributes []string = []string{ +func (x *LdapImpl) GetLdapUsers(log *log.Logger) ([]AuthUser, error) { + var attributes = []string{ "sAMAccountName", "givenName", "name", @@ -75,12 +93,21 @@ func (x *LdapImpl) GetLdapUsers() ([]AuthUser, error) { if err != nil { return nil, err } - defer con.Close() + defer func() { + _ = con.Close() + }() sr, err := con.SearchWithPaging(searchRequest, 100) if err != nil { + fmt.Println("LDAP search failed: ", err) return nil, err } + if x.config.DebugUserPull { + // print hierarchy by iterating over the tree, depth first + log.Infof("LDAP hierarchy:\n") + printHierarchy(extractHierarchy(sr), "", true, log) + } + getAttributeValue := func(entry ldap.Entry, attribute string) string { values := entry.GetAttributeValues(attribute) if len(values) == 0 { @@ -89,53 +116,167 @@ func (x *LdapImpl) GetLdapUsers() ([]AuthUser, error) { return values[0] } + ldapSource := make([]ldapEntry, len(sr.Entries)) + ldapUsers := make([]AuthUser, len(sr.Entries)) if x.config.DebugUserPull { - fmt.Printf("LDAP source data:\n") - fmt.Printf("%23v | %20v | %26v | %25v | %45v | %15v | %45v\n", "sAMAccountName", "givenName", "name", "sn", "mail", "mobile", "userPrincipalName") + log.Infof("%d records retrieved from LDAP server...\n", len(sr.Entries)) } - ldapUsers := make([]AuthUser, len(sr.Entries)) + allAttributes := make(map[string]struct{}) for i, value := range sr.Entries { // We trim the spaces as we have found that a certain ldap user // (WilburGS) has an email that ends with a space. - username := strings.TrimSpace(getAttributeValue(*value, "sAMAccountName")) - givenName := strings.TrimSpace(getAttributeValue(*value, "givenName")) - name := strings.TrimSpace(getAttributeValue(*value, "name")) - surname := strings.TrimSpace(getAttributeValue(*value, "sn")) - email := strings.TrimSpace(getAttributeValue(*value, "mail")) - mobile := strings.TrimSpace(getAttributeValue(*value, "mobile")) - userPrincipalName := strings.TrimSpace(getAttributeValue(*value, "userPrincipalName")) if x.config.DebugUserPull { - fmt.Printf("%23v | %20v | %26v | %25v | %45v | %15v | %45v\n", - username, givenName, name, surname, email, mobile, userPrincipalName) + log.Infof("LDAP raw entry: %+v\n", *value) + } + + for _, attr := range value.Attributes { + allAttributes[attr.Name] = struct{}{} } - if email == "" && strings.Count(userPrincipalName, "@") == 1 { + newEntry := ldapEntry{} + newEntry.UserName = strings.TrimSpace(getAttributeValue(*value, "sAMAccountName")) + newEntry.GivenName = strings.TrimSpace(getAttributeValue(*value, "givenName")) + newEntry.Name = strings.TrimSpace(getAttributeValue(*value, "name")) + newEntry.Surname = strings.TrimSpace(getAttributeValue(*value, "sn")) + newEntry.Email = strings.TrimSpace(getAttributeValue(*value, "mail")) + newEntry.Mobile = strings.TrimSpace(getAttributeValue(*value, "mobile")) + newEntry.UserPrincipalName = strings.TrimSpace(getAttributeValue(*value, "userPrincipalName")) + if newEntry.Email == "" && strings.Count(newEntry.UserPrincipalName, "@") == 1 { // This was first seen in Azure, when integrating with DTPW (Department of Transport and Public Works) - email = userPrincipalName + newEntry.Email = newEntry.UserPrincipalName } - firstName := givenName - if firstName == "" && surname == "" && name != "" { + firstName := newEntry.GivenName + if firstName == "" && newEntry.Surname == "" && newEntry.Name != "" { // We're in dubious best-guess-for-common-english territory here - firstSpace := strings.Index(name, " ") + firstSpace := strings.Index(newEntry.Name, " ") if firstSpace != -1 { - firstName = name[:firstSpace] - surname = name[firstSpace+1:] + firstName = newEntry.Name[:firstSpace] + newEntry.Surname = newEntry.Name[firstSpace+1:] } } - ldapUsers[i] = AuthUser{UserId: NullUserId, Email: email, Username: username, Firstname: firstName, Lastname: surname, Mobilenumber: mobile} + ldapSource[i] = newEntry + ldapUsers[i] = AuthUser{UserId: NullUserId, Email: newEntry.Email, Username: newEntry.UserName, Firstname: firstName, Lastname: newEntry.Surname, Mobilenumber: newEntry.Mobile} } + + // print if x.config.DebugUserPull { - fmt.Println() - fmt.Printf("Mapped to Auth users:\n") - fmt.Printf("%23v | %16v | %19v | %45v | %15v\n", "username", "firstname", "lastname", "email", "mobile") + log.Infof("All LDAP attributes seen:\n") + attributeNames := make([]string, 0, len(allAttributes)) + for attrName := range allAttributes { + attributeNames = append(attributeNames, attrName) + } + log.Infof("%v\n", strings.Join(attributeNames, ", ")) + + log.Infof("---\n") + log.Infof("LDAP source data:\n") + log.Infof("%23v | %20v | %26v | %25v | %45v | %15v | %45v\n", "sAMAccountName", "givenName", "name", "sn", "mail", "mobile", "userPrincipalName") + log.Infof("-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n") + for _, entry := range ldapSource { + log.Infof("%23v | %20v | %26v | %25v | %45v | %15v | %45v\n", + entry.UserName, entry.GivenName, entry.Name, entry.Surname, entry.Email, entry.Mobile, entry.UserPrincipalName) + } + + log.Infof("\n") + log.Infof("Mapped to Auth users:\n") + log.Infof("%23v | %16v | %19v | %45v | %15v\n", "username", "firstname", "lastname", "email", "mobile") + log.Infof("----------------------------------------------------------------------------------------------------------------------------------\n") for _, user := range ldapUsers { - fmt.Printf("%23v | %16v | %19v | %45v | %15v\n", user.Username, user.Firstname, user.Lastname, user.Email, user.Mobilenumber) + log.Infof("%23v | %16v | %19v | %45v | %15v\n", user.Username, user.Firstname, user.Lastname, user.Email, user.Mobilenumber) } } return ldapUsers, nil } +type hierarchyNode struct { + name string + nodeType string + children map[string]*hierarchyNode +} + +type pathElement struct { + name string + nodeType string +} + +// formatHierarchyPath returns the formatted path string for a node. +// If nodeType is "CN" and printNodesIfCN is true, uses colon separator. +func formatHierarchyPath(pathPrefix, nodeName, nodeType string, printNodesIfCN bool) string { + if nodeType == "CN" && printNodesIfCN { + return printTrunc(pathPrefix+" : "+nodeName, 120, "...") + } + return printTrunc(pathPrefix+"/"+nodeName, 120, "...") +} + +func printHierarchy(hierarchy *hierarchyNode, pathPrefix string, printNodesIfCN bool, log *log.Logger) { + if hierarchy == nil { + return + } + currentNode := hierarchy + // print current node + log.Infof("Node: %s\n", formatHierarchyPath(pathPrefix, currentNode.name, currentNode.nodeType, printNodesIfCN)) + + // iterate through children + pathPrefix = pathPrefix + "/" + currentNode.name + for _, child := range currentNode.children { + printHierarchy(child, pathPrefix, printNodesIfCN, log) + } +} + +// printTrunc produces an output such that len(output) <= maxLength +// +// If len(str) > maxLength, it will truncate str and append abbrevString such +// that the output is still of length maxLength. +func printTrunc(str string, maxLength int, abbrevString string) string { + if len(str) > maxLength { + return str[:maxLength-len(abbrevString)] + abbrevString + } else { + return str + } +} + +func extractHierarchy(sr *ldap.SearchResult) *hierarchyNode { + tree := &hierarchyNode{ + name: "ROOT", + nodeType: "ROOT", + children: make(map[string]*hierarchyNode), + } + + currentNode := tree + for _, a := range sr.Entries { + currentNode = tree + // New code + pathElements := extractPath(a.DN) + for _, part := range pathElements { + childNode, found := currentNode.children[part.name] + if !found { + childNode = &hierarchyNode{ + name: part.name, + nodeType: part.nodeType, + children: make(map[string]*hierarchyNode), + } + currentNode.children[part.name] = childNode + } + currentNode = childNode + } + } + return tree +} + +// extractPath extracts the top-down path elements from a DN string +func extractPath(dn string) []pathElement { + parts := strings.Split(dn, ",") + result := make([]pathElement, 0, len(parts)) + for i := len(parts) - 1; i >= 0; i-- { + part := strings.TrimSpace(parts[i]) + equalIndex := strings.Index(part, "=") + if equalIndex != -1 && equalIndex < len(part)-1 { + result = append(result, pathElement{part[equalIndex+1:], part[:equalIndex]}) + } + } + return result +} + func MergeLDAP(c *Central) { - ldapUsers, err := c.ldap.GetLdapUsers() + ldapUsers, err := c.ldap.GetLdapUsers(c.Log) if err != nil { c.Log.Warnf("Failed to retrieve users from LDAP server for merge to take place (%v)", err) return @@ -148,7 +289,9 @@ func MergeLDAP(c *Central) { MergeLdapUsersIntoLocalUserStore(c, ldapUsers, imqsUsers) } -// We are reading users from LDAP/AD and merging them into the IMQS userstore +// MergeLdapUsersIntoLocalUserStore +// +// Reads users from LDAP/AD and merges them into the IMQS user store func MergeLdapUsersIntoLocalUserStore(x *Central, ldapUsers []AuthUser, imqsUsers []AuthUser) { // Create maps from arrays imqsUserUsernameMap := make(map[string]AuthUser) @@ -268,41 +411,77 @@ func equalsForLDAPMerge(a, b AuthUser) bool { a.Username == b.Username } -func NewLDAPConnectAndBind(config *ConfigLDAP) (*ldap.LDAPConnection, error) { +func NewLDAPConnectAndBind(config *ConfigLDAP) (*ldap.Conn, error) { con, err := NewLDAPConnect(config) if err != nil { return nil, err } - if err := con.Bind(config.LdapUsername, config.LdapPassword); err != nil { + if err = con.Bind(config.LdapUsername, config.LdapPassword); err != nil { return nil, err } return con, nil } -func NewLDAPConnect(config *ConfigLDAP) (*ldap.LDAPConnection, error) { - con := ldap.NewLDAPConnection(config.LdapHost, config.LdapPort) - con.NetworkConnectTimeout = 30 * time.Second - con.ReadTimeout = 30 * time.Second +// NewLDAPConnect creates a new LDAP connection based on the configuration +// provided +// +// If the ldapPort is not specified, it defaults to 389. +// If connection is not-null, the calling function is required to close the +// connection when done with it. +func NewLDAPConnect(config *ConfigLDAP) (*ldap.Conn, error) { + if config.LdapPort == 0 { + config.LdapPort = 389 + } + + dialer := &net.Dialer{ + Timeout: 30 * time.Second, + } + ldapMode, legalLdapMode := configLdapNameToMode[config.Encryption] if !legalLdapMode { return nil, errors.New(fmt.Sprintf("Unknown ldap mode %v. Recognized modes are TLS, SSL, and empty for unencrypted", config.Encryption)) } + + // TODO : Switch protocol explicitly for SSL mode to ldaps:// + addr := "ldap://" + config.LdapHost + fmt.Sprintf(":%d", config.LdapPort) + switch ldapMode { case LdapConnectionModePlainText: + c, e := ldap.DialURL(addr, ldap.DialWithDialer(dialer)) + if e != nil { + return nil, e + } + c.SetTimeout(10 * time.Second) + return c, e + // DEPRECATED case LdapConnectionModeSSL: - con.IsSSL = true + // Use ldaps:// protocol and default port 636 if not specified + sslPort := config.LdapPort + if sslPort == 0 || sslPort == 389 { + sslPort = 636 + } + ldapsAddr := "ldaps://" + config.LdapHost + fmt.Sprintf(":%d", sslPort) + tlsConfig := &tls.Config{InsecureSkipVerify: config.InsecureSkipVerify} + c, e := ldap.DialURL(ldapsAddr, ldap.DialWithDialer(dialer), ldap.DialWithTLSConfig(tlsConfig)) + if e != nil { + return nil, e + } + c.SetTimeout(10 * time.Second) + return c, e case LdapConnectionModeTLS: - con.IsTLS = true - } - if config.InsecureSkipVerify { - con.TlsConfig = &tls.Config{} - con.TlsConfig.InsecureSkipVerify = config.InsecureSkipVerify - } - if err := con.Connect(); err != nil { - con.Close() - return nil, err + tlsConfig := &tls.Config{} + if config.InsecureSkipVerify { + tlsConfig.InsecureSkipVerify = config.InsecureSkipVerify + } + c, e := ldap.DialURL(addr, ldap.DialWithDialer(dialer), ldap.DialWithTLSConfig(tlsConfig)) + if e != nil { + return nil, e + } + c.SetTimeout(10 * time.Second) + return c, e + default: + return nil, errors.New("unimplemented LDAP connection mode") } - return con, nil } func NewAuthenticator_LDAP(config *ConfigLDAP) *LdapImpl { diff --git a/ldap_test.go b/ldap_test.go index eabdd01..6eb0e6e 100644 --- a/ldap_test.go +++ b/ldap_test.go @@ -1,7 +1,10 @@ package authaus import ( + "github.com/IMQS/log" + "github.com/go-ldap/ldap/v3" "github.com/stretchr/testify/assert" + "strings" "testing" "time" ) @@ -101,3 +104,267 @@ func TestLDAPUserDiffDiff(t *testing.T) { } t.Logf("User diff: \n%v", diff) } + +func Test_printTrunc(t *testing.T) { + type args struct { + name string + maxLength int + abbrevString string + } + tests := []struct { + name string + args args + want string + }{ + { + name: "Short string no trunc", + args: args{"This is a test string", 25, "..."}, + want: "This is a test string", + }, + { + name: "Short string limit", + args: args{"This is a test string xxx", 25, "..."}, + want: "This is a test string xxx", + }, + { + name: "Over by 1", + args: args{"This is a test string x y z", 25, "..."}, + want: "This is a test string ...", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equalf(t, tt.want, printTrunc(tt.args.name, tt.args.maxLength, tt.args.abbrevString), "printTrunc(%v, %v, %v)", tt.args.name, tt.args.maxLength, tt.args.abbrevString) + }) + } +} + +func Test_extractPath(t *testing.T) { + tests := []struct { + name string + dn string + want []pathElement + }{ + { + name: "empty DN", + dn: "", + want: []pathElement{}, + }, + { + name: "single DC component", + dn: "DC=com", + want: []pathElement{{name: "com", nodeType: "DC"}}, + }, + { + name: "normal DN", + dn: "CN=John Doe,OU=Users,DC=example,DC=com", + want: []pathElement{ + {name: "com", nodeType: "DC"}, + {name: "example", nodeType: "DC"}, + {name: "Users", nodeType: "OU"}, + {name: "John Doe", nodeType: "CN"}, + }, + }, + { + name: "DN with spaces around separators", + dn: "CN=Jane , OU=Staff , DC=corp , DC=org", + want: []pathElement{ + {name: "org", nodeType: "DC"}, + {name: "corp", nodeType: "DC"}, + {name: "Staff", nodeType: "OU"}, + {name: "Jane", nodeType: "CN"}, + }, + }, + { + name: "malformed element with no equals sign", + dn: "nodns", + want: []pathElement{}, + }, + { + name: "element with equals at end (empty value) is filtered out", + dn: "DC=com,OU=", + want: []pathElement{{name: "com", nodeType: "DC"}}, + }, + { + name: "value containing an equals sign", + dn: "CN=John=Doe,DC=com", + want: []pathElement{ + {name: "com", nodeType: "DC"}, + {name: "John=Doe", nodeType: "CN"}, + }, + }, + { + name: "mixed valid and malformed elements", + dn: "CN=Alice,badpart,DC=example", + want: []pathElement{ + {name: "example", nodeType: "DC"}, + {name: "Alice", nodeType: "CN"}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := extractPath(tt.dn) + assert.Equal(t, tt.want, got) + }) + } +} + +func Test_extractHierarchy(t *testing.T) { + t.Run("empty search result", func(t *testing.T) { + sr := &ldap.SearchResult{} + root := extractHierarchy(sr) + assert.NotNil(t, root) + assert.Equal(t, "ROOT", root.name) + assert.Equal(t, "ROOT", root.nodeType) + assert.Empty(t, root.children) + }) + + t.Run("single entry builds correct tree", func(t *testing.T) { + sr := &ldap.SearchResult{ + Entries: []*ldap.Entry{ + {DN: "CN=John Doe,OU=Users,DC=example,DC=com"}, + }, + } + root := extractHierarchy(sr) + assert.NotNil(t, root) + + // Expect: ROOT → com → example → Users → John Doe + comNode, ok := root.children["com"] + assert.True(t, ok, "expected 'com' child under ROOT") + assert.Equal(t, "DC", comNode.nodeType) + + exampleNode, ok := comNode.children["example"] + assert.True(t, ok, "expected 'example' child under 'com'") + assert.Equal(t, "DC", exampleNode.nodeType) + + usersNode, ok := exampleNode.children["Users"] + assert.True(t, ok, "expected 'Users' child under 'example'") + assert.Equal(t, "OU", usersNode.nodeType) + + johnNode, ok := usersNode.children["John Doe"] + assert.True(t, ok, "expected 'John Doe' child under 'Users'") + assert.Equal(t, "CN", johnNode.nodeType) + assert.Empty(t, johnNode.children) + }) + + t.Run("multiple entries share common path components", func(t *testing.T) { + sr := &ldap.SearchResult{ + Entries: []*ldap.Entry{ + {DN: "CN=Alice,OU=Users,DC=example,DC=com"}, + {DN: "CN=Bob,OU=Users,DC=example,DC=com"}, + {DN: "CN=Carol,OU=Admins,DC=example,DC=com"}, + }, + } + root := extractHierarchy(sr) + + comNode := root.children["com"] + assert.NotNil(t, comNode) + exampleNode := comNode.children["example"] + assert.NotNil(t, exampleNode) + + // Both OUs should be present + assert.Len(t, exampleNode.children, 2) + + usersNode := exampleNode.children["Users"] + assert.NotNil(t, usersNode) + assert.Contains(t, usersNode.children, "Alice") + assert.Contains(t, usersNode.children, "Bob") + + adminsNode := exampleNode.children["Admins"] + assert.NotNil(t, adminsNode) + assert.Contains(t, adminsNode.children, "Carol") + }) + + t.Run("entry with malformed DN produces no children", func(t *testing.T) { + sr := &ldap.SearchResult{ + Entries: []*ldap.Entry{ + {DN: "nodns"}, + }, + } + root := extractHierarchy(sr) + assert.NotNil(t, root) + assert.Empty(t, root.children) + }) +} + +func Test_printHierarchy(t *testing.T) { + logger := log.NewTesting(t) + + t.Run("nil node does not panic", func(t *testing.T) { + // This test checks that calling printHierarchy with a nil node does not panic. + assert.NotPanics(t, func() { + printHierarchy(nil, "", false, logger) + }) + // No output is asserted, as formatHierarchyPath is not called for nil. + }) + + t.Run("single node with OU type formats slash-separated path", func(t *testing.T) { + logger := log.NewTesting(t) + node := &hierarchyNode{ + name: "Users", + nodeType: "OU", + children: make(map[string]*hierarchyNode), + } + assert.NotPanics(t, func() { + printHierarchy(node, "ROOT", false, logger) + }) + path := formatHierarchyPath("ROOT", node.name, node.nodeType, false) + assert.Equal(t, "ROOT/Users", path) + }) + + t.Run("CN node with printNodesIfCN=true formats colon-separated path", func(t *testing.T) { + node := &hierarchyNode{ + name: "John Doe", + nodeType: "CN", + children: make(map[string]*hierarchyNode), + } + assert.NotPanics(t, func() { + printHierarchy(node, "ROOT/example/Users", true, logger) + }) + path := formatHierarchyPath("ROOT/example/Users", node.name, node.nodeType, true) + assert.Equal(t, "ROOT/example/Users : John Doe", path) + }) + + t.Run("tree with children recurses without panic", func(t *testing.T) { + child := &hierarchyNode{ + name: "John Doe", + nodeType: "CN", + children: make(map[string]*hierarchyNode), + } + parent := &hierarchyNode{ + name: "Users", + nodeType: "OU", + children: map[string]*hierarchyNode{"John Doe": child}, + } + assert.NotPanics(t, func() { + printHierarchy(parent, "ROOT", false, logger) + }) + }) +} + +func Test_formatHierarchyPath(t *testing.T) { + t.Run("OU node returns slash-separated path", func(t *testing.T) { + got := formatHierarchyPath("ROOT", "Users", "OU", false) + assert.Contains(t, got, "ROOT/Users") + }) + + t.Run("CN node with printNodesIfCN=false returns slash-separated path", func(t *testing.T) { + got := formatHierarchyPath("ROOT/Users", "John Doe", "CN", false) + assert.Contains(t, got, "ROOT/Users/John Doe") + }) + + t.Run("CN node with printNodesIfCN=true returns colon-separated path", func(t *testing.T) { + got := formatHierarchyPath("ROOT/Users", "John Doe", "CN", true) + assert.Contains(t, got, "ROOT/Users : John Doe") + }) + + // Ensure the generated path exceeds 120 characters for truncation + t.Run("truncation works for long paths", func(t *testing.T) { + prefix := strings.Repeat("a", 130) + longNode := strings.Repeat("b", 20) + got := formatHierarchyPath(prefix, longNode, "OU", false) + assert.True(t, len(got) <= 120, "Output should be truncated to 120 characters or less") + assert.True(t, strings.HasSuffix(got, "..."), "Output should end with '...'") + }) +} diff --git a/msaad.go b/msaad.go index ffe6af7..08731cd 100644 --- a/msaad.go +++ b/msaad.go @@ -689,7 +689,7 @@ func (m *MSAAD) populateAADRoles(users []*msaadUser) error { defer func() { if r := recover(); r != nil { s := GetStack() - errGlobal = fmt.Errorf(fmt.Sprintf("%v\n%v\n", r, s)) + errGlobal = fmt.Errorf("%v\n%v\n", r, s) } }() for _, user := range threadGroup {