From f4e67fa4cd924fbe6f271611514caf5589e6a6e5 Mon Sep 17 00:00:00 2001 From: ned Date: Wed, 12 Nov 2014 14:52:16 -0700 Subject: LDAP server support --- README | 33 -- README.md | 113 +++++++ bind.go | 2 +- conn.go | 3 +- control.go | 2 +- debug.go | 2 +- examples/cert_DONOTUSE.pem | 18 ++ examples/key_DONOTUSE.pem | 27 ++ examples/modify.go | 2 +- examples/proxy.go | 106 +++++++ examples/search.go | 2 +- examples/searchSSL.go | 2 +- examples/searchTLS.go | 2 +- examples/server.go | 64 ++++ filter.go | 161 +++++++++- filter_test.go | 4 +- ldap.go | 8 +- modify.go | 2 +- search.go | 2 +- server.go | 595 ++++++++++++++++++++++++++++++++++++ server_test.go | 732 +++++++++++++++++++++++++++++++++++++++++++++ 21 files changed, 1821 insertions(+), 61 deletions(-) delete mode 100644 README create mode 100644 README.md create mode 100644 examples/cert_DONOTUSE.pem create mode 100644 examples/key_DONOTUSE.pem create mode 100644 examples/proxy.go create mode 100644 examples/server.go create mode 100644 server.go create mode 100644 server_test.go diff --git a/README b/README deleted file mode 100644 index 8987bcb..0000000 --- a/README +++ /dev/null @@ -1,33 +0,0 @@ -Basic LDAP v3 functionality for the GO programming language. - -Required Librarys: - github.com/vanackere/asn1-ber - -Working: - Connecting to LDAP server - Binding to LDAP server - Searching for entries - Compiling string filters to LDAP filters - Paging Search Results - Modify Requests / Responses - -Examples: - search - modify - -Tests Implemented: - Filter Compile / Decompile - -TODO: - Add Requests / Responses - Delete Requests / Responses - Modify DN Requests / Responses - Compare Requests / Responses - Implement Tests / Benchmarks - -This feature is disabled at the moment, because in some cases the "Search Request Done" packet will be handled before the last "Search Request Entry": - Mulitple internal goroutines to handle network traffic - Makes library goroutine safe - Can perform multiple search requests at the same time and return - the results to the proper goroutine. All requests are blocking - requests, so the goroutine does not need special handling diff --git a/README.md b/README.md new file mode 100644 index 0000000..c72fca8 --- /dev/null +++ b/README.md @@ -0,0 +1,113 @@ +# LDAP for Golang + +This library provides basic LDAP v3 functionality for the GO programming language. + +The **client** portion is limited, but sufficient to perform LDAP authentication and directory lookups (binds and searches) against any modern LDAP server (tested with OpenLDAP and AD). + +The **server** portion implements Bind and Search from [RFC4510](http://tools.ietf.org/html/rfc4510), has good testing coverage, and is compatible with any LDAPv3 client. It provides the building blocks for a custom LDAP server, but you must implement the backend datastore of your choice. + + +## LDAP client notes: + +### A simple LDAP bind operation: +```go +l, err := ldap.Dial("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapPort)) +// be sure to add error checking! +defer l.Close() +err = l.Bind(user, passwd) +if err==nil { + // authenticated +} else { + // invalid authentication +} +``` + +### A simple LDAP search operation: +```go +search := &SearchRequest{ + BaseDN: "dc=example,dc=com", + Filter: "(objectclass=*)", +} +searchResults, err := l.Search(search) +// be sure to add error checking! +``` + +### Implemented: +* Connecting, binding to LDAP server +* Searching for entries with filtering and paging controls +* Compiling string filters to LDAP filters +* Modify Requests / Responses + +### Not implemented: +* Add, Delete, Modify DN, Compare operations +* Most tests / benchmarks + +### LDAP client examples: +* examples/search.go: **Basic client bind and search** +* examples/searchSSL.go: **Client bind and search over SSL** +* examples/searchTLS.go: **Client bind and search over TLS** +* examples/modify.go: **Client modify operation** + +*Client library by: [mmitton](https://github.com/mmitton), with contributions from: [uavila](https://github.com/uavila), [vanackere](https://github.com/vanackere), [juju2013](https://github.com/juju2013), [johnweldon](https://github.com/johnweldon), [marcsauter](https://github.com/marcsauter), and [nmcclain](https://github.com/nmcclain)* + +## LDAP server notes: +The server library is modeled after net/http - you designate handlers for the LDAP operations you want to support (Bind/Search/etc.), then start the server with ListenAndServe(). You can specify different handlers for different baseDNs - they must implement the interfaces of the operations you want to support: +```go +type Binder interface { + Bind(bindDN, bindSimplePw string, conn net.Conn) (uint64, error) +} +type Searcher interface { + Search(boundDN string, searchReq SearchRequest, conn net.Conn) (ServerSearchResult, error) +} +type Closer interface { + Close(conn net.Conn) error +} +``` + +### A basic bind-only LDAP server +```go +func main() { + s := ldap.NewServer() + handler := ldapHandler{} + s.BindFunc("", handler) + if err := s.ListenAndServe("localhost:389"); err != nil { + log.Fatal("LDAP Server Failed: %s", err.Error()) + } +} +type ldapHandler struct { +} +func (h ldapHandler) Bind(bindDN, bindSimplePw string, conn net.Conn) (uint64, error) { + if bindDN == "" && bindSimplePw == "" { + return ldap.LDAPResultSuccess, nil + } + return ldap.LDAPResultInvalidCredentials, nil +} +``` + +* Server.EnforceLDAP: Normally, the LDAP server will return whatever results your handler provides. Set the **Server.EnforceLDAP** flag to **true** and the server will apply the LDAP **search filter**, **attributes limits**, **size/time limits**, **search scope**, and **base DN matching** to your handler's dataset. This makes it a lot simpler to write a custom LDAP server without worrying about LDAP internals. + +### LDAP server examples: +* examples/server.go: **Basic LDAP authentication (bind and search only)** +* examples/proxy.go: **Simple LDAP proxy server.** +* server_test: **The tests have examples of all server functions.** + +*Warning: Do not use the example SSL certificates in production!* + +### Known limitations: + +* Golang's TLS implementation does not support SSLv2. Some old OSs require SSLv2, and are not able to connect to an LDAP server created with this library's ListenAndServeTLS() function. If you *must* support legacy (read: *insecure*) SSLv2 clients, run your LDAP server behind HAProxy. + +### Not implemented: +All of [RFC4510](http://tools.ietf.org/html/rfc4510) is implemented **except**: +* 4.1.11. Controls +* 4.5.1.3. SearchRequest.derefAliases +* 4.5.1.5. SearchRequest.timeLimit +* 4.5.1.6. SearchRequest.typesOnly +* 4.6. Modify Operation +* 4.7. Add Operation +* 4.8. Delete Operation +* 4.9. Modify DN Operation +* 4.10. Compare Operation +* 4.14. StartTLS Operation + +*Server library by: [nmcclain](https://github.com/nmcclain)* diff --git a/bind.go b/bind.go index c3ba8ce..171a2e9 100644 --- a/bind.go +++ b/bind.go @@ -7,7 +7,7 @@ package ldap import ( "errors" - "github.com/vanackere/asn1-ber" + "github.com/nmcclain/asn1-ber" ) func (l *Conn) Bind(username, password string) error { diff --git a/conn.go b/conn.go index ce0370d..9f40317 100644 --- a/conn.go +++ b/conn.go @@ -11,8 +11,7 @@ import ( "net" "sync" "time" - - "github.com/vanackere/asn1-ber" + "github.com/nmcclain/asn1-ber" ) const ( diff --git a/control.go b/control.go index e8022ba..dd46fea 100644 --- a/control.go +++ b/control.go @@ -7,7 +7,7 @@ package ldap import ( "fmt" - "github.com/vanackere/asn1-ber" + "github.com/nmcclain/asn1-ber" ) const ( diff --git a/debug.go b/debug.go index e6edfd4..de9bc5a 100644 --- a/debug.go +++ b/debug.go @@ -3,7 +3,7 @@ package ldap import ( "log" - "github.com/vanackere/asn1-ber" + "github.com/nmcclain/asn1-ber" ) // debbuging type diff --git a/examples/cert_DONOTUSE.pem b/examples/cert_DONOTUSE.pem new file mode 100644 index 0000000..ee14324 --- /dev/null +++ b/examples/cert_DONOTUSE.pem @@ -0,0 +1,18 @@ +-----BEGIN CERTIFICATE----- +MIIC9jCCAeCgAwIBAgIRAOG6xrSjAWQvJl9xFTIte/owCwYJKoZIhvcNAQELMBIx +EDAOBgNVBAoTB0FjbWUgQ28wHhcNMTQwODIwMTY1MjQ4WhcNMTUwODIwMTY1MjQ4 +WjASMRAwDgYDVQQKEwdBY21lIENvMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIB +CgKCAQEA30gcjawL7RXQ5B5IfAcCPsJkG3GBbfhkbBRI22VxktBoNvqh2TWyECG3 +WsB/N1WMmATnLxamBZ5mfouNbd120gbO1M06Ti57NP1YTmMp8AU18Dm4OjZ6IeQf +ip1xYSSSb6UyucFN6zIt+5PY2o4DoGb6fSNKb1ybgu91LmC1O/TDlyYUWn2TtF73 +FOUwSt+A6t3/Jhjhlp4n5Oobw1rrAgf7DPhWFg0Thj1yknPzWALY2LPREOMWob0D +EgR5C3WS2eYPyHkeMZWoSY6BiWTIU+hFqQUkdOvrWhflFoiZIsOl6iXmQpo2EQlg +j3Oy2zyZk1ndAfHlFoAgPIIbnBc+2QIDAQABo0swSTAOBgNVHQ8BAf8EBAMCAKAw +EwYDVR0lBAwwCgYIKwYBBQUHAwEwDAYDVR0TAQH/BAIwADAUBgNVHREEDTALggls +b2NhbGhvc3QwCwYJKoZIhvcNAQELA4IBAQBB9xNt3rDrBA9tCLCjdlnIQuUu9Uf0 +tHsSH6keBkhEoAylzHjmkNlerhTaLkRgB0D8qjE5+1APz42TuRpHRunYHSTNN0aF +N6zKlpXS0g+J/ViCh/Zw7xQI4mpSFqYzTgn4T733FqwLrmtKsj0IOOkDYSZc7qfh +qwXp/SB1J0Kp8G8S3G73dCZZYuW8y/eYMEoSkjNwNLAXzEAmFkGd8f1xhWTvnOxz +ZBbOOjggdRLxr7cMZ8GaVWFgEG93y3AYMhFxZYRwWTcWJvSTNP3xC/CWqxXkiKdO +2BROqmTw8zdqjXCIbgX4B5G5njMq9fk0gc4SiTAQkCOF6Xo0wQUvBAbN +-----END CERTIFICATE----- diff --git a/examples/key_DONOTUSE.pem b/examples/key_DONOTUSE.pem new file mode 100644 index 0000000..7feaa11 --- /dev/null +++ b/examples/key_DONOTUSE.pem @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEoQIBAAKCAQEA30gcjawL7RXQ5B5IfAcCPsJkG3GBbfhkbBRI22VxktBoNvqh +2TWyECG3WsB/N1WMmATnLxamBZ5mfouNbd120gbO1M06Ti57NP1YTmMp8AU18Dm4 +OjZ6IeQfip1xYSSSb6UyucFN6zIt+5PY2o4DoGb6fSNKb1ybgu91LmC1O/TDlyYU +Wn2TtF73FOUwSt+A6t3/Jhjhlp4n5Oobw1rrAgf7DPhWFg0Thj1yknPzWALY2LPR +EOMWob0DEgR5C3WS2eYPyHkeMZWoSY6BiWTIU+hFqQUkdOvrWhflFoiZIsOl6iXm +Qpo2EQlgj3Oy2zyZk1ndAfHlFoAgPIIbnBc+2QIDAQABAoIBAGXECi+QEMd4QAMY +wlS1JRLRqqrPavxiT/Lqs+I7NC6EClu0k/vZ+1Ra6aTVQ6ZGuZO3+F5/5h99eJ2I +oWdHnxZOwApBl6d2i/U02wCvNbgNx+27gPoXRkcYIEAfTkPGVW/JTXtYXVkrP8YA +NsA2JfT/un86jHyBKufcl/4RWcj/BddduEWZRAV7UH4iUVzr6kTHJHyIkrJrME4w +0Njrd/+tvjX/AsBfj+4IqZo6yMSVQuGhnJYEWt+eU3+5OlLsWe8oy9PJN2hVPxvG +bWaMx0/I88gSf9DQRnpIL2Kr715c4NUqy+82DP+Tg6h/lQ1oba68redtZoctDrmq +IXox6PkCgYEA+fElqxWqdDTOw+yhD4CeZ6yxipI/iUQu2t7atRmTgi2WAgaT5aAP +1lWhTmWPW7IzjYzFe/CK7OAe0P7JgmvBI2SMlSg/NWOx+pnTmOVO3buz46+B3VI9 +IFhMqAkVfnXukT0YgLsdvRTZb7irYeelVImZ0VVvn0+HzqZ1JUj8kwMCgYEA5LGK +jZ0NmBQT/yxsp1GWlmIJnTizlxEGoc7ftL4SNGcshMKbSK6aYnUvjnnRdFwjQUu4 +mAN0LlDn7SalEL7PUnotUMNA278o4Zj88VcWQqjukkThDpOFVcreXuGkqIfxYyn6 +jJxLouF6L0zspEhFiEWjNYOVDzZw7Fh0tOCxkfMCf1L8voUPrIjo/74N02xSSEYk +EM7xwCbTfLsvQ27eDxwqBqSlinWzr4564BQnpHHNuVBGbUu5kmcUAydhcYbcQESA +Hi1oL5SKhY2vhZI+kPEOYaw3mebiZ2lV6B3i5kAW6B9RKdGUT0t4oLl3l2/qefqX +tXrL40QCJBV5L2wxz6sCgYEAoDRXUTkSCtUV5Q3j15pqGVL4VTEhbdQ5hyR6xgzY +h+k24JHLYjEeaZaaB/8CYbch413+JE9XFhMLRbBqtb5VUfvQvuDpEIdrRg58MzzE +lVHuPn0OA74IC7+f42vCg2UoDkWcBOCAg8vcYkJLDBKs0veli5lv1EZY+NhGeWdm +PU0CgYAhEHZnVC8DuKAUxuIXEpDil3F7iGYs1rAnGd3GkKofiait+9YlsCxFx/4F +95VQjHm6Fdc+vwGUa2Z986wKmocWzVP3TbznMdbvr0/8LCOhQKDrtCkFWHtGsP/d +PnCkwIdaTEen0E52PkMK8GNq6wjitINzRp5hpV23WFtGQxmmlA== +-----END RSA PRIVATE KEY----- diff --git a/examples/modify.go b/examples/modify.go index 326598c..87d1119 100644 --- a/examples/modify.go +++ b/examples/modify.go @@ -11,7 +11,7 @@ import ( "fmt" "log" - "github.com/vanackere/ldap" + "github.com/nmcclain/ldap" ) var ( diff --git a/examples/proxy.go b/examples/proxy.go new file mode 100644 index 0000000..d6b01d0 --- /dev/null +++ b/examples/proxy.go @@ -0,0 +1,106 @@ +package main + +import ( + "crypto/sha256" + "fmt" + "github.com/nmcclain/ldap" + "log" + "net" + "sync" +) + +type ldapHandler struct { + sessions map[string]session + lock sync.Mutex + ldapServer string + ldapPort int +} + +///////////// Run a simple LDAP proxy +func main() { + s := ldap.NewServer() + + handler := ldapHandler{ + sessions: make(map[string]session), + ldapServer: "localhost", + ldapPort: 3389, + } + s.BindFunc("", handler) + s.SearchFunc("", handler) + s.CloseFunc("", handler) + + // start the server + if err := s.ListenAndServe("localhost:3388"); err != nil { + log.Fatal("LDAP Server Failed: %s", err.Error()) + } +} + +///////////// +type session struct { + id string + c net.Conn + ldap *ldap.Conn +} + +func (h ldapHandler) getSession(conn net.Conn) (session, error) { + id := connID(conn) + h.lock.Lock() + s, ok := h.sessions[id] // use server connection if it exists + h.lock.Unlock() + if !ok { // open a new server connection if not + l, err := ldap.Dial("tcp", fmt.Sprintf("%s:%d", h.ldapServer, h.ldapPort)) + if err != nil { + return session{}, err + } + s = session{id: id, c: conn, ldap: l} + h.lock.Lock() + h.sessions[s.id] = s + h.lock.Unlock() + } + return s, nil +} + +///////////// +func (h ldapHandler) Bind(bindDN, bindSimplePw string, conn net.Conn) (uint64, error) { + s, err := h.getSession(conn) + if err != nil { + return ldap.LDAPResultOperationsError, err + } + if err := s.ldap.Bind(bindDN, bindSimplePw); err != nil { + return ldap.LDAPResultOperationsError, err + } + return ldap.LDAPResultSuccess, nil +} + +///////////// +func (h ldapHandler) Search(boundDN string, searchReq ldap.SearchRequest, conn net.Conn) (ldap.ServerSearchResult, error) { + s, err := h.getSession(conn) + if err != nil { + return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, nil + } + search := ldap.NewSearchRequest( + searchReq.BaseDN, + ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false, + searchReq.Filter, + searchReq.Attributes, + nil) + sr, err := s.ldap.Search(search) + if err != nil { + return ldap.ServerSearchResult{}, err + } + //log.Printf("P: Search OK: %s -> num of entries = %d\n", search.Filter, len(sr.Entries)) + return ldap.ServerSearchResult{sr.Entries, []string{}, []ldap.Control{}, ldap.LDAPResultSuccess}, nil +} +func (h ldapHandler) Close(conn net.Conn) error { + conn.Close() // close connection to the server when then client is closed + h.lock.Lock() + defer h.lock.Unlock() + delete(h.sessions, connID(conn)) + return nil +} +func connID(conn net.Conn) string { + h := sha256.New() + h.Write([]byte(conn.LocalAddr().String() + conn.RemoteAddr().String())) + sha := fmt.Sprintf("% x", h.Sum(nil)) + return string(sha) +} diff --git a/examples/search.go b/examples/search.go index 93e941c..08b364a 100644 --- a/examples/search.go +++ b/examples/search.go @@ -10,7 +10,7 @@ import ( "fmt" "log" - "github.com/vanackere/ldap" + "github.com/nmcclain/ldap" ) var ( diff --git a/examples/searchSSL.go b/examples/searchSSL.go index db2f7b8..75c8395 100644 --- a/examples/searchSSL.go +++ b/examples/searchSSL.go @@ -10,7 +10,7 @@ import ( "fmt" "log" - "github.com/vanackere/ldap" + "github.com/nmcclain/ldap" ) var ( diff --git a/examples/searchTLS.go b/examples/searchTLS.go index b4dce8a..56b3d27 100644 --- a/examples/searchTLS.go +++ b/examples/searchTLS.go @@ -10,7 +10,7 @@ import ( "fmt" "log" - "github.com/vanackere/ldap" + "github.com/nmcclain/ldap" ) var ( diff --git a/examples/server.go b/examples/server.go new file mode 100644 index 0000000..dca74ed --- /dev/null +++ b/examples/server.go @@ -0,0 +1,64 @@ +package main + +import ( + "github.com/nmcclain/ldap" + "log" + "net" +) + +///////////// +// Sample searches you can try against this simple LDAP server: +// +// ldapsearch -H ldap://localhost:3389 -x -b 'dn=test,dn=com' +// ldapsearch -H ldap://localhost:3389 -x -b 'dn=test,dn=com' 'cn=ned' +// ldapsearch -H ldap://localhost:3389 -x -b 'dn=test,dn=com' 'uidnumber=5000' +///////////// + +///////////// Run a simple LDAP server +func main() { + s := ldap.NewServer() + + // register Bind and Search function handlers + handler := ldapHandler{} + s.BindFunc("", handler) + s.SearchFunc("", handler) + + // start the server + if err := s.ListenAndServe("localhost:3389"); err != nil { + log.Fatal("LDAP Server Failed: %s", err.Error()) + } +} + +type ldapHandler struct { +} + +///////////// Allow anonymous binds only +func (h ldapHandler) Bind(bindDN, bindSimplePw string, conn net.Conn) (uint64, error) { + if bindDN == "" && bindSimplePw == "" { + return ldap.LDAPResultSuccess, nil + } + return ldap.LDAPResultInvalidCredentials, nil +} + +///////////// Return some hardcoded search results - we'll respond to any baseDN for testing +func (h ldapHandler) Search(boundDN string, searchReq ldap.SearchRequest, conn net.Conn) (ldap.ServerSearchResult, error) { + entries := []*ldap.Entry{ + &ldap.Entry{"cn=ned," + searchReq.BaseDN, []*ldap.EntryAttribute{ + &ldap.EntryAttribute{"cn", []string{"ned"}}, + &ldap.EntryAttribute{"uidNumber", []string{"5000"}}, + &ldap.EntryAttribute{"accountStatus", []string{"active"}}, + &ldap.EntryAttribute{"uid", []string{"ned"}}, + &ldap.EntryAttribute{"description", []string{"ned"}}, + &ldap.EntryAttribute{"objectClass", []string{"posixAccount"}}, + }}, + &ldap.Entry{"cn=trent," + searchReq.BaseDN, []*ldap.EntryAttribute{ + &ldap.EntryAttribute{"cn", []string{"trent"}}, + &ldap.EntryAttribute{"uidNumber", []string{"5005"}}, + &ldap.EntryAttribute{"accountStatus", []string{"active"}}, + &ldap.EntryAttribute{"uid", []string{"trent"}}, + &ldap.EntryAttribute{"description", []string{"trent"}}, + &ldap.EntryAttribute{"objectClass", []string{"posixAccount"}}, + }}, + } + return ldap.ServerSearchResult{entries, []string{}, []ldap.Control{}, ldap.LDAPResultSuccess}, nil +} diff --git a/filter.go b/filter.go index 690f67d..d7bc798 100644 --- a/filter.go +++ b/filter.go @@ -7,8 +7,8 @@ package ldap import ( "errors" "fmt" - - "github.com/vanackere/asn1-ber" + "github.com/nmcclain/asn1-ber" + "strings" ) const ( @@ -24,7 +24,7 @@ const ( FilterExtensibleMatch = 9 ) -var filterMap = map[uint8]string{ +var FilterMap = map[uint8]string{ FilterAnd: "And", FilterOr: "Or", FilterNot: "Not", @@ -163,15 +163,15 @@ func compileFilter(filter string, pos int) (*ber.Packet, int, error) { newPos++ return packet, newPos, err case '&': - packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterAnd, nil, filterMap[FilterAnd]) + packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterAnd, nil, FilterMap[FilterAnd]) newPos, err = compileFilterSet(filter, pos+1, packet) return packet, newPos, err case '|': - packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterOr, nil, filterMap[FilterOr]) + packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterOr, nil, FilterMap[FilterOr]) newPos, err = compileFilterSet(filter, pos+1, packet) return packet, newPos, err case '!': - packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterNot, nil, filterMap[FilterNot]) + packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterNot, nil, FilterMap[FilterNot]) var child *ber.Packet child, newPos, err = compileFilter(filter, pos+1) packet.AppendChild(child) @@ -184,15 +184,15 @@ func compileFilter(filter string, pos int) (*ber.Packet, int, error) { case packet != nil: condition += fmt.Sprintf("%c", filter[newPos]) case filter[newPos] == '=': - packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterEqualityMatch, nil, filterMap[FilterEqualityMatch]) + packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterEqualityMatch, nil, FilterMap[FilterEqualityMatch]) case filter[newPos] == '>' && filter[newPos+1] == '=': - packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterGreaterOrEqual, nil, filterMap[FilterGreaterOrEqual]) + packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterGreaterOrEqual, nil, FilterMap[FilterGreaterOrEqual]) newPos++ case filter[newPos] == '<' && filter[newPos+1] == '=': - packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterLessOrEqual, nil, filterMap[FilterLessOrEqual]) + packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterLessOrEqual, nil, FilterMap[FilterLessOrEqual]) newPos++ case filter[newPos] == '~' && filter[newPos+1] == '=': - packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterApproxMatch, nil, filterMap[FilterLessOrEqual]) + packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterApproxMatch, nil, FilterMap[FilterLessOrEqual]) newPos++ case packet == nil: attribute += fmt.Sprintf("%c", filter[newPos]) @@ -211,7 +211,7 @@ func compileFilter(filter string, pos int) (*ber.Packet, int, error) { if packet.Tag == FilterEqualityMatch && condition == "*" { packet.TagType = ber.TypePrimitive packet.Tag = FilterPresent - packet.Description = filterMap[packet.Tag] + packet.Description = FilterMap[packet.Tag] packet.Data.WriteString(attribute) return packet, newPos + 1, nil } @@ -220,21 +220,21 @@ func compileFilter(filter string, pos int) (*ber.Packet, int, error) { case packet.Tag == FilterEqualityMatch && condition[0] == '*' && condition[len(condition)-1] == '*': // Any packet.Tag = FilterSubstrings - packet.Description = filterMap[packet.Tag] + packet.Description = FilterMap[packet.Tag] seq := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Substrings") seq.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, FilterSubstringsAny, condition[1:len(condition)-1], "Any Substring")) packet.AppendChild(seq) case packet.Tag == FilterEqualityMatch && condition[0] == '*': // Final packet.Tag = FilterSubstrings - packet.Description = filterMap[packet.Tag] + packet.Description = FilterMap[packet.Tag] seq := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Substrings") seq.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, FilterSubstringsFinal, condition[1:], "Final Substring")) packet.AppendChild(seq) case packet.Tag == FilterEqualityMatch && condition[len(condition)-1] == '*': // Initial packet.Tag = FilterSubstrings - packet.Description = filterMap[packet.Tag] + packet.Description = FilterMap[packet.Tag] seq := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Substrings") seq.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, FilterSubstringsInitial, condition[:len(condition)-1], "Initial Substring")) packet.AppendChild(seq) @@ -245,3 +245,136 @@ func compileFilter(filter string, pos int) (*ber.Packet, int, error) { return packet, newPos, err } } + +func ServerApplyFilter(f *ber.Packet, entry *Entry) (bool, uint64) { + //log.Printf("%# v", pretty.Formatter(entry)) + + switch FilterMap[f.Tag] { + default: + //log.Fatalf("Unknown LDAP filter code: %d", f.Tag) + return false, LDAPResultOperationsError + case "Equality Match": + if len(f.Children) != 2 { + return false, LDAPResultOperationsError + } + attribute := f.Children[0].Value.(string) + value := f.Children[1].Value.(string) + for _, a := range entry.Attributes { + if strings.ToLower(a.Name) == strings.ToLower(attribute) { + for _, v := range a.Values { + if strings.ToLower(v) == strings.ToLower(value) { + return true, LDAPResultSuccess + } + } + } + } + case "Present": + for _, a := range entry.Attributes { + if strings.ToLower(a.Name) == strings.ToLower(f.Data.String()) { + return true, LDAPResultSuccess + } + } + case "And": + for _, child := range f.Children { + ok, exitCode := ServerApplyFilter(child, entry) + if exitCode != LDAPResultSuccess { + return false, exitCode + } + if !ok { + return false, LDAPResultSuccess + } + } + return true, LDAPResultSuccess + case "Or": + anyOk := false + for _, child := range f.Children { + ok, exitCode := ServerApplyFilter(child, entry) + if exitCode != LDAPResultSuccess { + return false, exitCode + } else if ok { + anyOk = true + } + } + if anyOk { + return true, LDAPResultSuccess + } + case "Not": + if len(f.Children) != 1 { + return false, LDAPResultOperationsError + } + ok, exitCode := ServerApplyFilter(f.Children[0], entry) + if exitCode != LDAPResultSuccess { + return false, exitCode + } else if !ok { + return true, LDAPResultSuccess + } + case "FilterSubstrings": + return false, LDAPResultOperationsError + case "FilterGreaterOrEqual": + return false, LDAPResultOperationsError + case "FilterLessOrEqual": + return false, LDAPResultOperationsError + case "FilterApproxMatch": + return false, LDAPResultOperationsError + case "FilterExtensibleMatch": + return false, LDAPResultOperationsError + } + + return false, LDAPResultSuccess +} + +func GetFilterType(filter string) (string, error) { // TODO <- test this + f, err := CompileFilter(filter) + if err != nil { + return "", err + } + return parseFilterType(f) +} +func parseFilterType(f *ber.Packet) (string, error) { + searchType := "" + switch FilterMap[f.Tag] { + case "Equality Match": + if len(f.Children) != 2 { + return "", errors.New("Equality match must have only two children") + } + attribute := strings.ToLower(f.Children[0].Value.(string)) + value := f.Children[1].Value.(string) + + if attribute == "objectclass" { + searchType = strings.ToLower(value) + } + case "And": + for _, child := range f.Children { + subType, err := parseFilterType(child) + if err != nil { + return "", err + } + if len(subType) > 0 { + searchType = subType + } + } + case "Or": + for _, child := range f.Children { + subType, err := parseFilterType(child) + if err != nil { + return "", err + } + if len(subType) > 0 { + searchType = subType + } + } + case "Not": + if len(f.Children) != 1 { + return "", errors.New("Not filter must have only one child") + } + subType, err := parseFilterType(f.Children[0]) + if err != nil { + return "", err + } + if len(subType) > 0 { + searchType = subType + } + + } + return strings.ToLower(searchType), nil +} diff --git a/filter_test.go b/filter_test.go index baecab9..fb54905 100644 --- a/filter_test.go +++ b/filter_test.go @@ -4,7 +4,7 @@ import ( "reflect" "testing" - "github.com/vanackere/asn1-ber" + "github.com/nmcclain/asn1-ber" ) type compileTest struct { @@ -34,7 +34,7 @@ func TestFilter(t *testing.T) { if err != nil { t.Errorf("Problem compiling %s - %s", i.filterStr, err.Error()) } else if filter.Tag != uint8(i.filterType) { - t.Errorf("%q Expected %q got %q", i.filterStr, filterMap[i.filterType], filterMap[filter.Tag]) + t.Errorf("%q Expected %q got %q", i.filterStr, FilterMap[i.filterType], FilterMap[filter.Tag]) } else { o, err := DecompileFilter(filter) if err != nil { diff --git a/ldap.go b/ldap.go index 08c01e8..42c50d6 100644 --- a/ldap.go +++ b/ldap.go @@ -9,7 +9,7 @@ import ( "fmt" "io/ioutil" - "github.com/vanackere/asn1-ber" + "github.com/nmcclain/asn1-ber" ) // LDAP Application Codes @@ -149,6 +149,12 @@ var LDAPResultCodeMap = map[uint8]string{ LDAPResultOther: "Other", } +// Other LDAP constants +const ( + LDAPBindAuthSimple = 0 + LDAPBindAuthSASL = 3 +) + // Adds descriptions to an LDAP Response packet for debugging func addLDAPDescriptions(packet *ber.Packet) (err error) { defer func() { diff --git a/modify.go b/modify.go index 8c0e852..7decf2c 100644 --- a/modify.go +++ b/modify.go @@ -33,7 +33,7 @@ import ( "errors" "log" - "github.com/vanackere/asn1-ber" + "github.com/nmcclain/asn1-ber" ) const ( diff --git a/search.go b/search.go index d0ce3f5..45b26b8 100644 --- a/search.go +++ b/search.go @@ -64,7 +64,7 @@ import ( "fmt" "strings" - "github.com/vanackere/asn1-ber" + "github.com/nmcclain/asn1-ber" ) const ( diff --git a/server.go b/server.go new file mode 100644 index 0000000..4a46e6f --- /dev/null +++ b/server.go @@ -0,0 +1,595 @@ +package ldap + +import ( + "crypto/tls" + "errors" + "fmt" + "github.com/nmcclain/asn1-ber" + "io" + "log" + "net" + "strings" + "sync" +) + +type Binder interface { + Bind(bindDN, bindSimplePw string, conn net.Conn) (uint64, error) +} +type Searcher interface { + Search(boundDN string, searchReq SearchRequest, conn net.Conn) (ServerSearchResult, error) +} +type Closer interface { + Close(conn net.Conn) error +} + +///////////////////////// +type Server struct { + bindFns map[string]Binder + searchFns map[string]Searcher + closeFns map[string]Closer + quit chan bool + EnforceLDAP bool + stats *Stats +} + +type Stats struct { + Conns int + Binds int + Unbinds int + Searches int + statsMutex sync.Mutex +} + +type ServerSearchResult struct { + Entries []*Entry + Referrals []string + Controls []Control + ResultCode uint64 +} + +///////////////////////// +func NewServer() *Server { + s := new(Server) + s.quit = make(chan bool) + + d := defaultHandler{} + s.bindFns = make(map[string]Binder) + s.searchFns = make(map[string]Searcher) + s.closeFns = make(map[string]Closer) + s.bindFns[""] = d + s.searchFns[""] = d + s.closeFns[""] = d + s.stats = nil + return s +} +func (server *Server) BindFunc(baseDN string, bindFn Binder) { + server.bindFns[baseDN] = bindFn +} +func (server *Server) SearchFunc(baseDN string, searchFn Searcher) { + server.searchFns[baseDN] = searchFn +} +func (server *Server) CloseFunc(baseDN string, closeFn Closer) { + server.closeFns[baseDN] = closeFn +} +func (server *Server) QuitChannel(quit chan bool) { + server.quit = quit +} + +func (server *Server) ListenAndServeTLS(listenString string, certFile string, keyFile string) error { + cert, err := tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + return err + } + tlsConfig := tls.Config{Certificates: []tls.Certificate{cert}} + tlsConfig.ServerName = "localhost" + ln, err := tls.Listen("tcp", listenString, &tlsConfig) + if err != nil { + return err + } + err = server.serve(ln) + if err != nil { + return err + } + return nil +} + +func (server *Server) SetStats(enable bool) { + if enable { + server.stats = &Stats{} + } else { + server.stats = nil + } +} + +func (server *Server) GetStats() Stats { + defer func() { + server.stats.statsMutex.Unlock() + }() + server.stats.statsMutex.Lock() + return *server.stats +} + +func (server *Server) ListenAndServe(listenString string) error { + ln, err := net.Listen("tcp", listenString) + if err != nil { + return err + } + err = server.serve(ln) + if err != nil { + return err + } + return nil +} + +func (server *Server) serve(ln net.Listener) error { + newConn := make(chan net.Conn) + go func() { + for { + conn, err := ln.Accept() + if err != nil { + if !strings.HasSuffix(err.Error(), "use of closed network connection") { + log.Printf("Error accepting network connection: %s", err.Error()) + } + break + } + newConn <- conn + } + }() + +listener: + for { + select { + case c := <-newConn: + server.stats.countConns(1) + go server.handleConnection(c) + case <-server.quit: + ln.Close() + break listener + } + } + return nil +} + +///////////////////////// + +func (server *Server) handleConnection(conn net.Conn) { + boundDN := "" // "" == anonymous + +handler: + for { + // read incoming LDAP packet + packet, err := ber.ReadPacket(conn) + if err == io.EOF { // Client closed connection + break + } else if err != nil { + log.Printf("handleConnection ber.ReadPacket ERROR: %s", err.Error()) + break + } + + // sanity check this packet + if len(packet.Children) < 2 { + log.Print("len(packet.Children) < 2") + break + } + // check the message ID and ClassType + messageID := packet.Children[0].Value.(uint64) + req := packet.Children[1] + if req.ClassType != ber.ClassApplication { + log.Print("req.ClassType != ber.ClassApplication") + break + } + // handle controls if present + if len(packet.Children) > 2 { + controls := packet.Children[2] + ber.PrintPacket(controls) + log.Print("TODO Parse Controls") + /* + Controls ::= SEQUENCE OF control Control + + Control ::= SEQUENCE { + controlType LDAPOID, + criticality BOOLEAN DEFAULT FALSE, // unavailableCriticalExtension + controlValue OCTET STRING OPTIONAL } + */ + } + + // dispatch the LDAP operation + switch req.Tag { // ldap op code + default: + //log.Printf("Bound as %s", boundDN) + //ber.PrintPacket(packet) + log.Printf("Unhandled operation: %s [%d]", ApplicationMap[req.Tag], req.Tag) + break handler + + case ApplicationBindRequest: + server.stats.countBinds(1) + ldapResultCode := server.handleBindRequest(req, server.bindFns, conn) + if ldapResultCode == LDAPResultSuccess { + boundDN = req.Children[1].Value.(string) + } + responsePacket := encodeBindResponse(messageID, ldapResultCode) + if err = sendPacket(conn, responsePacket); err != nil { + log.Printf("sendPacket error %s", err.Error()) + break handler + } + case ApplicationSearchRequest: + server.stats.countSearches(1) + if err := server.handleSearchRequest(req, messageID, boundDN, server.searchFns, conn); err != nil { + log.Printf("handleSearchRequest error %s", err.Error()) // TODO: make this more testable/better err handling - stop using log, stop using breaks? + e := err.(*Error) + if err = sendPacket(conn, encodeSearchDone(messageID, uint64(e.ResultCode))); err != nil { + log.Printf("sendPacket error %s", err.Error()) + } + break handler + } else { + if err = sendPacket(conn, encodeSearchDone(messageID, LDAPResultSuccess)); err != nil { + log.Printf("sendPacket error %s", err.Error()) + break handler + } + } + case ApplicationUnbindRequest: + server.stats.countUnbinds(1) + break handler // simply disconnect - this IS implemented + case ApplicationExtendedRequest: + responsePacket := encodeLDAPResponse(messageID, ApplicationExtendedResponse, LDAPResultProtocolError, "Unsupported extended request") + if err = sendPacket(conn, responsePacket); err != nil { + log.Printf("sendPacket error %s", err.Error()) + } + break handler + case ApplicationAbandonRequest: + log.Printf("Abandoning request!") + break handler + + // Unimplemented LDAP operations: + case ApplicationModifyRequest: + log.Printf("Unhandled operation: %s [%d]", ApplicationMap[req.Tag], req.Tag) + break handler + case ApplicationAddRequest: + log.Printf("Unhandled operation: %s [%d]", ApplicationMap[req.Tag], req.Tag) + break handler + case ApplicationDelRequest: + log.Printf("Unhandled operation: %s [%d]", ApplicationMap[req.Tag], req.Tag) + break handler + case ApplicationModifyDNRequest: + log.Printf("Unhandled operation: %s [%d]", ApplicationMap[req.Tag], req.Tag) + break handler + case ApplicationCompareRequest: + log.Printf("Unhandled operation: %s [%d]", ApplicationMap[req.Tag], req.Tag) + break handler + } + } + + for _, c := range server.closeFns { + c.Close(conn) + } + + conn.Close() +} + +///////////////////////// +func (server *Server) handleSearchRequest(req *ber.Packet, messageID uint64, boundDN string, searchFns map[string]Searcher, conn net.Conn) (resultErr error) { + defer func() { + if r := recover(); r != nil { + resultErr = NewError(LDAPResultOperationsError, fmt.Errorf("Search function panic: %s", r)) + } + }() + + searchReq, err := parseSearchRequest(boundDN, req) + if err != nil { + return NewError(LDAPResultOperationsError, err) + } + + filterPacket, err := CompileFilter(searchReq.Filter) + if err != nil { + return NewError(LDAPResultOperationsError, err) + } + + fnNames := []string{} + for k := range searchFns { + fnNames = append(fnNames, k) + } + searchFn := routeFunc(searchReq.BaseDN, fnNames) + searchResp, err := searchFns[searchFn].Search(boundDN, searchReq, conn) + if err != nil { + return NewError(uint8(searchResp.ResultCode), err) + } + + if server.EnforceLDAP { + if searchReq.DerefAliases != NeverDerefAliases { // [-a {never|always|search|find} + // TODO: Server DerefAliases not implemented: RFC4511 4.5.1.3. SearchRequest.derefAliases + } + if len(searchReq.Controls) > 0 { + return NewError(LDAPResultOperationsError, errors.New("Server controls not implemented")) // TODO + } + if searchReq.TimeLimit > 0 { + return NewError(LDAPResultOperationsError, errors.New("Server TimeLimit not implemented")) // TODO + } + } + + for i, entry := range searchResp.Entries { + if server.EnforceLDAP { + // size limit + if searchReq.SizeLimit > 0 && i >= searchReq.SizeLimit { + break + } + + // filter + keep, resultCode := ServerApplyFilter(filterPacket, entry) + if resultCode != LDAPResultSuccess { + return NewError(uint8(resultCode), errors.New("ServerApplyFilter error")) + } + if !keep { + continue + } + + // constrained search scope + switch searchReq.Scope { + case ScopeWholeSubtree: // The scope is constrained to the entry named by baseObject and to all its subordinates. + case ScopeBaseObject: // The scope is constrained to the entry named by baseObject. + if entry.DN != searchReq.BaseDN { + continue + } + case ScopeSingleLevel: // The scope is constrained to the immediate subordinates of the entry named by baseObject. + parts := strings.Split(entry.DN, ",") + if len(parts) < 2 && entry.DN != searchReq.BaseDN { + continue + } + if dn := strings.Join(parts[1:], ","); dn != searchReq.BaseDN { + continue + } + } + + // attributes + if len(searchReq.Attributes) > 1 || (len(searchReq.Attributes) == 1 && len(searchReq.Attributes[0]) > 0) { + entry, err = filterAttributes(entry, searchReq.Attributes) + if err != nil { + return NewError(LDAPResultOperationsError, err) + } + } + } + + // respond + responsePacket := encodeSearchResponse(messageID, searchReq, entry) + if err = sendPacket(conn, responsePacket); err != nil { + return NewError(LDAPResultOperationsError, err) + } + } + return nil +} + +///////////////////////// +func (server *Server) handleBindRequest(req *ber.Packet, bindFns map[string]Binder, conn net.Conn) (resultCode uint64) { + defer func() { + if r := recover(); r != nil { + resultCode = LDAPResultOperationsError + } + }() + + // we only support ldapv3 + ldapVersion := req.Children[0].Value.(uint64) + if ldapVersion != 3 { + log.Printf("Unsupported LDAP version: %d", ldapVersion) + return LDAPResultInappropriateAuthentication + } + + // auth types + bindDN := req.Children[1].Value.(string) + bindAuth := req.Children[2] + switch bindAuth.Tag { + default: + log.Print("Unknown LDAP authentication method") + return LDAPResultInappropriateAuthentication + case LDAPBindAuthSimple: + if len(req.Children) == 3 { + fnNames := []string{} + for k := range bindFns { + fnNames = append(fnNames, k) + } + bindFn := routeFunc(bindDN, fnNames) + resultCode, err := bindFns[bindFn].Bind(bindDN, bindAuth.Data.String(), conn) + if err != nil { + log.Printf("BindFn Error %s", err.Error()) + } + return resultCode + } else { + log.Print("Simple bind request has wrong # children. len(req.Children) != 3") + return LDAPResultInappropriateAuthentication + } + case LDAPBindAuthSASL: + log.Print("SASL authentication is not supported") + return LDAPResultInappropriateAuthentication + } + return LDAPResultOperationsError +} + +///////////////////////// +func sendPacket(conn net.Conn, packet *ber.Packet) error { + _, err := conn.Write(packet.Bytes()) + if err != nil { + log.Printf("Error Sending Message: %s", err.Error()) + return err + } + return nil +} + +///////////////////////// +func parseSearchRequest(boundDN string, req *ber.Packet) (SearchRequest, error) { + if len(req.Children) != 8 { + return SearchRequest{}, NewError(LDAPResultOperationsError, errors.New("Bad search request")) + } + + // Parse the request + baseObject := req.Children[0].Value.(string) + scope := int(req.Children[1].Value.(uint64)) + derefAliases := int(req.Children[2].Value.(uint64)) + sizeLimit := int(req.Children[3].Value.(uint64)) + timeLimit := int(req.Children[4].Value.(uint64)) + typesOnly := false + if req.Children[5].Value != nil { + typesOnly = req.Children[5].Value.(bool) + } + filter, err := DecompileFilter(req.Children[6]) + if err != nil { + return SearchRequest{}, err + } + attributes := []string{} + for _, attr := range req.Children[7].Children { + attributes = append(attributes, attr.Value.(string)) + } + searchReq := SearchRequest{baseObject, scope, + derefAliases, sizeLimit, timeLimit, + typesOnly, filter, attributes, nil} + + return searchReq, nil +} + +///////////////////////// +func routeFunc(dn string, funcNames []string) string { + bestPick := "" + for _, fn := range funcNames { + if strings.HasSuffix(dn, fn) { + l := len(strings.Split(bestPick, ",")) + if bestPick == "" { + l = 0 + } + if len(strings.Split(fn, ",")) > l { + bestPick = fn + } + } + } + return bestPick +} + +///////////////////////// +func filterAttributes(entry *Entry, attributes []string) (*Entry, error) { + // only return requested attributes + newAttributes := []*EntryAttribute{} + + for _, attr := range entry.Attributes { + for _, requested := range attributes { + if strings.ToLower(attr.Name) == strings.ToLower(requested) { + newAttributes = append(newAttributes, attr) + } + } + } + entry.Attributes = newAttributes + + return entry, nil +} + +///////////////////////// +type defaultHandler struct { +} + +func (h defaultHandler) Bind(bindDN, bindSimplePw string, conn net.Conn) (uint64, error) { + return LDAPResultInappropriateAuthentication, nil +} +func (h defaultHandler) Search(boundDN string, searchReq SearchRequest, conn net.Conn) (ServerSearchResult, error) { + return ServerSearchResult{make([]*Entry, 0), []string{}, []Control{}, LDAPResultSuccess}, nil +} +func (h defaultHandler) Close(conn net.Conn) error { + conn.Close() + return nil +} + +///////////////////////// +func encodeBindResponse(messageID uint64, ldapResultCode uint64) *ber.Packet { + responsePacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response") + responsePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "Message ID")) + + bindReponse := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindResponse, nil, "Bind Response") + bindReponse.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, ldapResultCode, "resultCode: ")) + bindReponse.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "matchedDN: ")) + bindReponse.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "errorMessage: ")) + + responsePacket.AppendChild(bindReponse) + + // ber.PrintPacket(responsePacket) + return responsePacket +} +func encodeSearchResponse(messageID uint64, req SearchRequest, res *Entry) *ber.Packet { + responsePacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response") + responsePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "Message ID")) + + searchEntry := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationSearchResultEntry, nil, "Search Result Entry") + searchEntry.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, res.DN, "Object Name")) + + attrs := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Attributes:") + for _, attribute := range res.Attributes { + attrs.AppendChild(encodeSearchAttribute(attribute.Name, attribute.Values)) + } + + searchEntry.AppendChild(attrs) + responsePacket.AppendChild(searchEntry) + + return responsePacket +} + +func encodeSearchAttribute(name string, values []string) *ber.Packet { + packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Attribute") + packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, name, "Attribute Name")) + + valuesPacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSet, nil, "Attribute Values") + for _, value := range values { + valuesPacket.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, value, "Attribute Value")) + } + + packet.AppendChild(valuesPacket) + + return packet +} + +func encodeSearchDone(messageID uint64, ldapResultCode uint64) *ber.Packet { + responsePacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response") + responsePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "Message ID")) + donePacket := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationSearchResultDone, nil, "Search result done") + donePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, ldapResultCode, "resultCode: ")) + donePacket.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "matchedDN: ")) + donePacket.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "errorMessage: ")) + responsePacket.AppendChild(donePacket) + + return responsePacket +} + +func encodeLDAPResponse(messageID uint64, responseType uint8, ldapResultCode uint64, message string) *ber.Packet { + responsePacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response") + responsePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "Message ID")) + reponse := ber.Encode(ber.ClassApplication, ber.TypeConstructed, responseType, nil, ApplicationMap[responseType]) + reponse.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, ldapResultCode, "resultCode: ")) + reponse.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "matchedDN: ")) + reponse.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, message, "errorMessage: ")) + responsePacket.AppendChild(reponse) + return responsePacket +} + +///////////////////////// +func (stats *Stats) countConns(delta int) { + if stats != nil { + stats.statsMutex.Lock() + stats.Conns += delta + stats.statsMutex.Unlock() + } +} +func (stats *Stats) countBinds(delta int) { + if stats != nil { + stats.statsMutex.Lock() + stats.Binds += delta + stats.statsMutex.Unlock() + } +} +func (stats *Stats) countUnbinds(delta int) { + if stats != nil { + stats.statsMutex.Lock() + stats.Unbinds += delta + stats.statsMutex.Unlock() + } +} +func (stats *Stats) countSearches(delta int) { + if stats != nil { + stats.statsMutex.Lock() + stats.Searches += delta + stats.statsMutex.Unlock() + } +} + +///////////////////////// diff --git a/server_test.go b/server_test.go new file mode 100644 index 0000000..9386a4a --- /dev/null +++ b/server_test.go @@ -0,0 +1,732 @@ +package ldap + +import ( + "bytes" + "log" + "net" + "os/exec" + "strings" + "testing" + "time" +) + +var listenString = "localhost:3389" +var ldapURL = "ldap://" + listenString +var timeout = 400 * time.Millisecond +var serverBaseDN = "o=testers,c=test" + +///////////////////////// +func TestBindAnonOK(t *testing.T) { + quit := make(chan bool) + done := make(chan bool) + go func() { + s := NewServer() + s.QuitChannel(quit) + s.BindFunc("", bindAnonOK{}) + if err := s.ListenAndServe(listenString); err != nil { + t.Errorf("s.ListenAndServe failed: %s", err.Error()) + } + }() + + go func() { + cmd := exec.Command("ldapsearch", "-H", ldapURL, "-x", "-b", "o=testers,c=test") + out, _ := cmd.CombinedOutput() + if !strings.Contains(string(out), "result: 0 Success") { + t.Errorf("ldapsearch failed: %v", string(out)) + } + done <- true + }() + + select { + case <-done: + case <-time.After(timeout): + t.Errorf("ldapsearch command timed out") + } + quit <- true +} + +///////////////////////// +func TestBindAnonFail(t *testing.T) { + quit := make(chan bool) + done := make(chan bool) + go func() { + s := NewServer() + s.QuitChannel(quit) + if err := s.ListenAndServe(listenString); err != nil { + t.Errorf("s.ListenAndServe failed: %s", err.Error()) + } + }() + + time.Sleep(timeout) + go func() { + cmd := exec.Command("ldapsearch", "-H", ldapURL, "-x", "-b", "o=testers,c=test") + out, _ := cmd.CombinedOutput() + if !strings.Contains(string(out), "ldap_bind: Inappropriate authentication (48)") { + t.Errorf("ldapsearch failed: %v", string(out)) + } + done <- true + }() + + select { + case <-done: + case <-time.After(timeout): + t.Errorf("ldapsearch command timed out") + } + time.Sleep(timeout) + quit <- true +} + +///////////////////////// +func TestBindSimpleOK(t *testing.T) { + quit := make(chan bool) + done := make(chan bool) + go func() { + s := NewServer() + s.QuitChannel(quit) + s.SearchFunc("", searchSimple{}) + s.BindFunc("", bindSimple{}) + if err := s.ListenAndServe(listenString); err != nil { + t.Errorf("s.ListenAndServe failed: %s", err.Error()) + } + }() + + serverBaseDN := "o=testers,c=test" + + go func() { + cmd := exec.Command("ldapsearch", "-H", ldapURL, "-x", + "-b", serverBaseDN, "-D", "cn=testy,"+serverBaseDN, "-w", "iLike2test") + out, _ := cmd.CombinedOutput() + if !strings.Contains(string(out), "result: 0 Success") { + t.Errorf("ldapsearch failed: %v", string(out)) + } + done <- true + }() + + select { + case <-done: + case <-time.After(timeout): + t.Errorf("ldapsearch command timed out") + } + quit <- true +} + +///////////////////////// +func TestBindSimpleFailBadPw(t *testing.T) { + quit := make(chan bool) + done := make(chan bool) + go func() { + s := NewServer() + s.QuitChannel(quit) + s.BindFunc("", bindSimple{}) + if err := s.ListenAndServe(listenString); err != nil { + t.Errorf("s.ListenAndServe failed: %s", err.Error()) + } + }() + + serverBaseDN := "o=testers,c=test" + + go func() { + cmd := exec.Command("ldapsearch", "-H", ldapURL, "-x", + "-b", serverBaseDN, "-D", "cn=testy,"+serverBaseDN, "-w", "BADPassword") + out, _ := cmd.CombinedOutput() + if !strings.Contains(string(out), "ldap_bind: Invalid credentials (49)") { + t.Errorf("ldapsearch succeeded - should have failed: %v", string(out)) + } + done <- true + }() + + select { + case <-done: + case <-time.After(timeout): + t.Errorf("ldapsearch command timed out") + } + quit <- true +} + +///////////////////////// +func TestBindSimpleFailBadDn(t *testing.T) { + quit := make(chan bool) + done := make(chan bool) + go func() { + s := NewServer() + s.QuitChannel(quit) + s.BindFunc("", bindSimple{}) + if err := s.ListenAndServe(listenString); err != nil { + t.Errorf("s.ListenAndServe failed: %s", err.Error()) + } + }() + + serverBaseDN := "o=testers,c=test" + + go func() { + cmd := exec.Command("ldapsearch", "-H", ldapURL, "-x", + "-b", serverBaseDN, "-D", "cn=testoy,"+serverBaseDN, "-w", "iLike2test") + out, _ := cmd.CombinedOutput() + if string(out) != "ldap_bind: Invalid credentials (49)\n" { + t.Errorf("ldapsearch succeeded - should have failed: %v", string(out)) + } + done <- true + }() + + select { + case <-done: + case <-time.After(timeout): + t.Errorf("ldapsearch command timed out") + } + quit <- true +} + +///////////////////////// +func TestBindSSL(t *testing.T) { + ldapURLSSL := "ldaps://" + listenString + longerTimeout := 300 * time.Millisecond + quit := make(chan bool) + done := make(chan bool) + go func() { + s := NewServer() + s.QuitChannel(quit) + s.BindFunc("", bindAnonOK{}) + if err := s.ListenAndServeTLS(listenString, "examples/cert_DONOTUSE.pem", "examples/key_DONOTUSE.pem"); err != nil { + t.Errorf("s.ListenAndServeTLS failed: %s", err.Error()) + } + }() + + go func() { + time.Sleep(longerTimeout * 2) + cmd := exec.Command("ldapsearch", "-H", ldapURLSSL, "-x", "-b", "o=testers,c=test") + out, _ := cmd.CombinedOutput() + if !strings.Contains(string(out), "result: 0 Success") { + t.Errorf("ldapsearch failed: %v", string(out)) + } + done <- true + }() + + select { + case <-done: + case <-time.After(longerTimeout * 2): + t.Errorf("ldapsearch command timed out") + } + quit <- true +} + +///////////////////////// +func TestBindPanic(t *testing.T) { + quit := make(chan bool) + done := make(chan bool) + go func() { + s := NewServer() + s.QuitChannel(quit) + s.BindFunc("", bindPanic{}) + if err := s.ListenAndServe(listenString); err != nil { + t.Errorf("s.ListenAndServe failed: %s", err.Error()) + } + }() + + go func() { + cmd := exec.Command("ldapsearch", "-H", ldapURL, "-x", "-b", "o=testers,c=test") + out, _ := cmd.CombinedOutput() + if !strings.Contains(string(out), "ldap_bind: Operations error") { + t.Errorf("ldapsearch should have returned operations error due to panic: %v", string(out)) + } + done <- true + }() + + select { + case <-done: + case <-time.After(timeout): + t.Errorf("ldapsearch command timed out") + } + quit <- true +} + +///////////////////////// +func TestSearchSimpleOK(t *testing.T) { + quit := make(chan bool) + done := make(chan bool) + go func() { + s := NewServer() + s.QuitChannel(quit) + s.SearchFunc("", searchSimple{}) + s.BindFunc("", bindSimple{}) + if err := s.ListenAndServe(listenString); err != nil { + t.Errorf("s.ListenAndServe failed: %s", err.Error()) + } + }() + + serverBaseDN := "o=testers,c=test" + + go func() { + cmd := exec.Command("ldapsearch", "-H", ldapURL, "-x", + "-b", serverBaseDN, "-D", "cn=testy,"+serverBaseDN, "-w", "iLike2test") + out, _ := cmd.CombinedOutput() + if !strings.Contains(string(out), "dn: cn=ned,o=testers,c=test") { + t.Errorf("ldapsearch failed: %v", string(out)) + } + if !strings.Contains(string(out), "uidNumber: 5000") { + t.Errorf("ldapsearch failed: %v", string(out)) + } + if !strings.Contains(string(out), "result: 0 Success") { + t.Errorf("ldapsearch failed: %v", string(out)) + } + if !strings.Contains(string(out), "numResponses: 4") { + t.Errorf("ldapsearch failed: %v", string(out)) + } + done <- true + }() + + select { + case <-done: + case <-time.After(timeout): + t.Errorf("ldapsearch command timed out") + } + quit <- true +} + +func TestSearchSizelimit(t *testing.T) { + quit := make(chan bool) + done := make(chan bool) + go func() { + s := NewServer() + s.EnforceLDAP = true + s.QuitChannel(quit) + s.SearchFunc("", searchSimple{}) + s.BindFunc("", bindSimple{}) + if err := s.ListenAndServe(listenString); err != nil { + t.Errorf("s.ListenAndServe failed: %s", err.Error()) + } + }() + + go func() { + cmd := exec.Command("ldapsearch", "-H", ldapURL, "-x", + "-b", serverBaseDN, "-D", "cn=testy,"+serverBaseDN, "-w", "iLike2test", "-z", "9") // effectively no limit for this test + out, _ := cmd.CombinedOutput() + if !strings.Contains(string(out), "result: 0 Success") { + t.Errorf("ldapsearch failed: %v", string(out)) + } + if !strings.Contains(string(out), "numEntries: 3") { + t.Errorf("ldapsearch sizelimit failed - not enough entries: %v", string(out)) + } + + cmd = exec.Command("ldapsearch", "-H", ldapURL, "-x", + "-b", serverBaseDN, "-D", "cn=testy,"+serverBaseDN, "-w", "iLike2test", "-z", "2") + out, _ = cmd.CombinedOutput() + if !strings.Contains(string(out), "result: 0 Success") { + t.Errorf("ldapsearch failed: %v", string(out)) + } + if !strings.Contains(string(out), "numEntries: 2") { + t.Errorf("ldapsearch sizelimit failed - too many entries: %v", string(out)) + } + done <- true + }() + + select { + case <-done: + case <-time.After(timeout): + t.Errorf("ldapsearch command timed out") + } + quit <- true +} + +///////////////////////// +func TestBindSearchMulti(t *testing.T) { + quit := make(chan bool) + done := make(chan bool) + go func() { + s := NewServer() + s.QuitChannel(quit) + s.BindFunc("", bindSimple{}) + s.BindFunc("c=testz", bindSimple2{}) + s.SearchFunc("", searchSimple{}) + s.SearchFunc("c=testz", searchSimple2{}) + if err := s.ListenAndServe(listenString); err != nil { + t.Errorf("s.ListenAndServe failed: %s", err.Error()) + } + }() + + go func() { + cmd := exec.Command("ldapsearch", "-H", ldapURL, "-x", "-b", "o=testers,c=test", + "-D", "cn=testy,o=testers,c=test", "-w", "iLike2test", "cn=ned") + out, _ := cmd.CombinedOutput() + if !strings.Contains(string(out), "result: 0 Success") { + t.Errorf("error routing default bind/search functions: %v", string(out)) + } + if !strings.Contains(string(out), "dn: cn=ned,o=testers,c=test") { + t.Errorf("search default routing failed: %v", string(out)) + } + cmd = exec.Command("ldapsearch", "-H", ldapURL, "-x", "-b", "o=testers,c=testz", + "-D", "cn=testy,o=testers,c=testz", "-w", "ZLike2test", "cn=hamburger") + out, _ = cmd.CombinedOutput() + if !strings.Contains(string(out), "result: 0 Success") { + t.Errorf("error routing custom bind/search functions: %v", string(out)) + } + if !strings.Contains(string(out), "dn: cn=hamburger,o=testers,c=testz") { + t.Errorf("search custom routing failed: %v", string(out)) + } + done <- true + }() + + select { + case <-done: + case <-time.After(timeout): + t.Errorf("ldapsearch command timed out") + } + + quit <- true +} + +///////////////////////// +func TestSearchPanic(t *testing.T) { + quit := make(chan bool) + done := make(chan bool) + go func() { + s := NewServer() + s.QuitChannel(quit) + s.SearchFunc("", searchPanic{}) + s.BindFunc("", bindAnonOK{}) + if err := s.ListenAndServe(listenString); err != nil { + t.Errorf("s.ListenAndServe failed: %s", err.Error()) + } + }() + + go func() { + cmd := exec.Command("ldapsearch", "-H", ldapURL, "-x", "-b", "o=testers,c=test") + out, _ := cmd.CombinedOutput() + if !strings.Contains(string(out), "result: 1 Operations error") { + t.Errorf("ldapsearch should have returned operations error due to panic: %v", string(out)) + } + done <- true + }() + + select { + case <-done: + case <-time.After(timeout): + t.Errorf("ldapsearch command timed out") + } + quit <- true +} + +///////////////////////// +type compileSearchFilterTest struct { + name string + filterStr string + numResponses string +} + +var searchFilterTestFilters = []compileSearchFilterTest{ + compileSearchFilterTest{name: "equalityOk", filterStr: "(uid=ned)", numResponses: "2"}, + compileSearchFilterTest{name: "equalityNo", filterStr: "(uid=foo)", numResponses: "1"}, + compileSearchFilterTest{name: "equalityOk", filterStr: "(objectclass=posixaccount)", numResponses: "4"}, + compileSearchFilterTest{name: "presentEmptyOk", filterStr: "", numResponses: "4"}, + compileSearchFilterTest{name: "presentOk", filterStr: "(objectclass=*)", numResponses: "4"}, + compileSearchFilterTest{name: "presentOk", filterStr: "(description=*)", numResponses: "3"}, + compileSearchFilterTest{name: "presentNo", filterStr: "(foo=*)", numResponses: "1"}, + compileSearchFilterTest{name: "andOk", filterStr: "(&(uid=ned)(objectclass=posixaccount))", numResponses: "2"}, + compileSearchFilterTest{name: "andNo", filterStr: "(&(uid=ned)(objectclass=posixgroup))", numResponses: "1"}, + compileSearchFilterTest{name: "andNo", filterStr: "(&(uid=ned)(uid=trent))", numResponses: "1"}, + compileSearchFilterTest{name: "orOk", filterStr: "(|(uid=ned)(uid=trent))", numResponses: "3"}, + compileSearchFilterTest{name: "orOk", filterStr: "(|(uid=ned)(objectclass=posixaccount))", numResponses: "4"}, + compileSearchFilterTest{name: "orNo", filterStr: "(|(uid=foo)(objectclass=foo))", numResponses: "1"}, + compileSearchFilterTest{name: "andOrOk", filterStr: "(&(|(uid=ned)(uid=trent))(objectclass=posixaccount))", numResponses: "3"}, + compileSearchFilterTest{name: "notOk", filterStr: "(!(uid=ned))", numResponses: "3"}, + compileSearchFilterTest{name: "notOk", filterStr: "(!(uid=foo))", numResponses: "4"}, + compileSearchFilterTest{name: "notAndOrOk", filterStr: "(&(|(uid=ned)(uid=trent))(!(objectclass=posixgroup)))", numResponses: "3"}, + /* + compileSearchFilterTest{filterStr: "(sn=Mill*)", filterType: FilterSubstrings}, + compileSearchFilterTest{filterStr: "(sn=*Mill)", filterType: FilterSubstrings}, + compileSearchFilterTest{filterStr: "(sn=*Mill*)", filterType: FilterSubstrings}, + compileSearchFilterTest{filterStr: "(sn>=Miller)", filterType: FilterGreaterOrEqual}, + compileSearchFilterTest{filterStr: "(sn<=Miller)", filterType: FilterLessOrEqual}, + compileSearchFilterTest{filterStr: "(sn~=Miller)", filterType: FilterApproxMatch}, + */ +} + +///////////////////////// +func TestSearchFiltering(t *testing.T) { + quit := make(chan bool) + done := make(chan bool) + go func() { + s := NewServer() + s.EnforceLDAP = true + s.QuitChannel(quit) + s.SearchFunc("", searchSimple{}) + s.BindFunc("", bindSimple{}) + if err := s.ListenAndServe(listenString); err != nil { + t.Errorf("s.ListenAndServe failed: %s", err.Error()) + } + }() + + for _, i := range searchFilterTestFilters { + t.Log(i.name) + + go func() { + cmd := exec.Command("ldapsearch", "-H", ldapURL, "-x", + "-b", serverBaseDN, "-D", "cn=testy,"+serverBaseDN, "-w", "iLike2test", i.filterStr) + out, _ := cmd.CombinedOutput() + if !strings.Contains(string(out), "numResponses: "+i.numResponses) { + t.Errorf("ldapsearch failed - expected numResponses==%d: %v", i.numResponses, string(out)) + } + done <- true + }() + + select { + case <-done: + case <-time.After(timeout): + t.Errorf("ldapsearch command timed out") + } + } + quit <- true +} + +///////////////////////// +func TestSearchAttributes(t *testing.T) { + quit := make(chan bool) + done := make(chan bool) + go func() { + s := NewServer() + s.EnforceLDAP = true + s.QuitChannel(quit) + s.SearchFunc("", searchSimple{}) + s.BindFunc("", bindSimple{}) + if err := s.ListenAndServe(listenString); err != nil { + t.Errorf("s.ListenAndServe failed: %s", err.Error()) + } + }() + + go func() { + filterString := "" + cmd := exec.Command("ldapsearch", "-H", ldapURL, "-x", + "-b", serverBaseDN, "-D", "cn=testy,"+serverBaseDN, "-w", "iLike2test", filterString, "cn") + out, _ := cmd.CombinedOutput() + + if !strings.Contains(string(out), "dn: cn=ned,o=testers,c=test") { + t.Errorf("ldapsearch failed - missing requested DN attribute: %v", string(out)) + } + if !strings.Contains(string(out), "cn: ned") { + t.Errorf("ldapsearch failed - missing requested CN attribute: %v", string(out)) + } + if strings.Contains(string(out), "uidNumber") { + t.Errorf("ldapsearch failed - uidNumber attr should not be displayed: %v", string(out)) + } + if strings.Contains(string(out), "accountstatus") { + t.Errorf("ldapsearch failed - accountstatus attr should not be displayed: %v", string(out)) + } + done <- true + }() + + select { + case <-done: + case <-time.After(timeout): + t.Errorf("ldapsearch command timed out") + } + quit <- true +} + +///////////////////////// +func TestSearchScope(t *testing.T) { + quit := make(chan bool) + done := make(chan bool) + go func() { + s := NewServer() + s.EnforceLDAP = true + s.QuitChannel(quit) + s.SearchFunc("", searchSimple{}) + s.BindFunc("", bindSimple{}) + if err := s.ListenAndServe(listenString); err != nil { + t.Errorf("s.ListenAndServe failed: %s", err.Error()) + } + }() + + go func() { + cmd := exec.Command("ldapsearch", "-H", ldapURL, "-x", + "-b", "c=test", "-D", "cn=testy,o=testers,c=test", "-w", "iLike2test", "-s", "sub", "cn=trent") + out, _ := cmd.CombinedOutput() + if !strings.Contains(string(out), "dn: cn=trent,o=testers,c=test") { + t.Errorf("ldapsearch 'sub' scope failed - didn't find expected DN: %v", string(out)) + } + + cmd = exec.Command("ldapsearch", "-H", ldapURL, "-x", + "-b", "o=testers,c=test", "-D", "cn=testy,o=testers,c=test", "-w", "iLike2test", "-s", "one", "cn=trent") + out, _ = cmd.CombinedOutput() + if !strings.Contains(string(out), "dn: cn=trent,o=testers,c=test") { + t.Errorf("ldapsearch 'one' scope failed - didn't find expected DN: %v", string(out)) + } + cmd = exec.Command("ldapsearch", "-H", ldapURL, "-x", + "-b", "c=test", "-D", "cn=testy,o=testers,c=test", "-w", "iLike2test", "-s", "one", "cn=trent") + out, _ = cmd.CombinedOutput() + if strings.Contains(string(out), "dn: cn=trent,o=testers,c=test") { + t.Errorf("ldapsearch 'one' scope failed - found unexpected DN: %v", string(out)) + } + + cmd = exec.Command("ldapsearch", "-H", ldapURL, "-x", + "-b", "cn=trent,o=testers,c=test", "-D", "cn=testy,o=testers,c=test", "-w", "iLike2test", "-s", "base", "cn=trent") + out, _ = cmd.CombinedOutput() + if !strings.Contains(string(out), "dn: cn=trent,o=testers,c=test") { + t.Errorf("ldapsearch 'base' scope failed - didn't find expected DN: %v", string(out)) + } + cmd = exec.Command("ldapsearch", "-H", ldapURL, "-x", + "-b", "o=testers,c=test", "-D", "cn=testy,o=testers,c=test", "-w", "iLike2test", "-s", "base", "cn=trent") + out, _ = cmd.CombinedOutput() + if strings.Contains(string(out), "dn: cn=trent,o=testers,c=test") { + t.Errorf("ldapsearch 'base' scope failed - found unexpected DN: %v", string(out)) + } + + done <- true + }() + + select { + case <-done: + case <-time.After(timeout): + t.Errorf("ldapsearch command timed out") + } + quit <- true +} + +///////////////////////// +type testStatsWriter struct { + buffer *bytes.Buffer +} + +func (tsw testStatsWriter) Write(buf []byte) (int, error) { + tsw.buffer.Write(buf) + return len(buf), nil +} + +func TestSearchStats(t *testing.T) { + w := testStatsWriter{&bytes.Buffer{}} + log.SetOutput(w) + + quit := make(chan bool) + done := make(chan bool) + s := NewServer() + + go func() { + s.QuitChannel(quit) + s.SearchFunc("", searchSimple{}) + s.BindFunc("", bindAnonOK{}) + s.SetStats(true) + if err := s.ListenAndServe(listenString); err != nil { + t.Errorf("s.ListenAndServe failed: %s", err.Error()) + } + }() + + go func() { + cmd := exec.Command("ldapsearch", "-H", ldapURL, "-x", "-b", "o=testers,c=test") + out, _ := cmd.CombinedOutput() + if !strings.Contains(string(out), "result: 0 Success") { + t.Errorf("ldapsearch failed: %v", string(out)) + } + done <- true + }() + + select { + case <-done: + case <-time.After(timeout): + t.Errorf("ldapsearch command timed out") + } + + stats := s.GetStats() + if stats.Conns != 1 || stats.Binds != 1 { + t.Errorf("Stats data missing or incorrect: %v", w.buffer.String()) + } + quit <- true +} + +///////////////////////// +type bindAnonOK struct { +} + +func (b bindAnonOK) Bind(bindDN, bindSimplePw string, conn net.Conn) (uint64, error) { + if bindDN == "" && bindSimplePw == "" { + return LDAPResultSuccess, nil + } + return LDAPResultInvalidCredentials, nil +} + +type bindSimple struct { +} + +func (b bindSimple) Bind(bindDN, bindSimplePw string, conn net.Conn) (uint64, error) { + if bindDN == "cn=testy,o=testers,c=test" && bindSimplePw == "iLike2test" { + return LDAPResultSuccess, nil + } + return LDAPResultInvalidCredentials, nil +} + +type bindSimple2 struct { +} + +func (b bindSimple2) Bind(bindDN, bindSimplePw string, conn net.Conn) (uint64, error) { + if bindDN == "cn=testy,o=testers,c=testz" && bindSimplePw == "ZLike2test" { + return LDAPResultSuccess, nil + } + return LDAPResultInvalidCredentials, nil +} + +type bindPanic struct { +} + +func (b bindPanic) Bind(bindDN, bindSimplePw string, conn net.Conn) (uint64, error) { + panic("test panic at the disco") + return LDAPResultInvalidCredentials, nil +} + +type searchSimple struct { +} + +func (s searchSimple) Search(boundDN string, searchReq SearchRequest, conn net.Conn) (ServerSearchResult, error) { + entries := []*Entry{ + &Entry{"cn=ned,o=testers,c=test", []*EntryAttribute{ + &EntryAttribute{"cn", []string{"ned"}}, + &EntryAttribute{"o", []string{"ate"}}, + &EntryAttribute{"uidNumber", []string{"5000"}}, + &EntryAttribute{"accountstatus", []string{"active"}}, + &EntryAttribute{"uid", []string{"ned"}}, + &EntryAttribute{"description", []string{"ned via sa"}}, + &EntryAttribute{"objectclass", []string{"posixaccount"}}, + }}, + &Entry{"cn=trent,o=testers,c=test", []*EntryAttribute{ + &EntryAttribute{"cn", []string{"trent"}}, + &EntryAttribute{"o", []string{"ate"}}, + &EntryAttribute{"uidNumber", []string{"5005"}}, + &EntryAttribute{"accountstatus", []string{"active"}}, + &EntryAttribute{"uid", []string{"trent"}}, + &EntryAttribute{"description", []string{"trent via sa"}}, + &EntryAttribute{"objectclass", []string{"posixaccount"}}, + }}, + &Entry{"cn=randy,o=testers,c=test", []*EntryAttribute{ + &EntryAttribute{"cn", []string{"randy"}}, + &EntryAttribute{"o", []string{"ate"}}, + &EntryAttribute{"uidNumber", []string{"5555"}}, + &EntryAttribute{"accountstatus", []string{"active"}}, + &EntryAttribute{"uid", []string{"randy"}}, + &EntryAttribute{"objectclass", []string{"posixaccount"}}, + }}, + } + return ServerSearchResult{entries, []string{}, []Control{}, LDAPResultSuccess}, nil +} + +type searchSimple2 struct { +} + +func (s searchSimple2) Search(boundDN string, searchReq SearchRequest, conn net.Conn) (ServerSearchResult, error) { + entries := []*Entry{ + &Entry{"cn=hamburger,o=testers,c=testz", []*EntryAttribute{ + &EntryAttribute{"cn", []string{"hamburger"}}, + &EntryAttribute{"o", []string{"testers"}}, + &EntryAttribute{"uidNumber", []string{"5000"}}, + &EntryAttribute{"accountstatus", []string{"active"}}, + &EntryAttribute{"uid", []string{"hamburger"}}, + &EntryAttribute{"objectclass", []string{"posixaccount"}}, + }}, + } + return ServerSearchResult{entries, []string{}, []Control{}, LDAPResultSuccess}, nil +} + +type searchPanic struct { +} + +func (s searchPanic) Search(boundDN string, searchReq SearchRequest, conn net.Conn) (ServerSearchResult, error) { + entries := []*Entry{} + panic("this is a test panic") + return ServerSearchResult{entries, []string{}, []Control{}, LDAPResultSuccess}, nil +} -- cgit v1.2.3