summaryrefslogtreecommitdiffstats
path: root/server.go
diff options
context:
space:
mode:
Diffstat (limited to 'server.go')
-rw-r--r--server.go595
1 files changed, 595 insertions, 0 deletions
diff --git a/server.go b/server.go
new file mode 100644
index 0000000..4a46e6f
--- /dev/null
+++ b/server.go
@@ -0,0 +1,595 @@
+package ldap
+
+import (
+ "crypto/tls"
+ "errors"
+ "fmt"
+ "github.com/nmcclain/asn1-ber"
+ "io"
+ "log"
+ "net"
+ "strings"
+ "sync"
+)
+
+type Binder interface {
+ Bind(bindDN, bindSimplePw string, conn net.Conn) (uint64, error)
+}
+type Searcher interface {
+ Search(boundDN string, searchReq SearchRequest, conn net.Conn) (ServerSearchResult, error)
+}
+type Closer interface {
+ Close(conn net.Conn) error
+}
+
+/////////////////////////
+type Server struct {
+ bindFns map[string]Binder
+ searchFns map[string]Searcher
+ closeFns map[string]Closer
+ quit chan bool
+ EnforceLDAP bool
+ stats *Stats
+}
+
+type Stats struct {
+ Conns int
+ Binds int
+ Unbinds int
+ Searches int
+ statsMutex sync.Mutex
+}
+
+type ServerSearchResult struct {
+ Entries []*Entry
+ Referrals []string
+ Controls []Control
+ ResultCode uint64
+}
+
+/////////////////////////
+func NewServer() *Server {
+ s := new(Server)
+ s.quit = make(chan bool)
+
+ d := defaultHandler{}
+ s.bindFns = make(map[string]Binder)
+ s.searchFns = make(map[string]Searcher)
+ s.closeFns = make(map[string]Closer)
+ s.bindFns[""] = d
+ s.searchFns[""] = d
+ s.closeFns[""] = d
+ s.stats = nil
+ return s
+}
+func (server *Server) BindFunc(baseDN string, bindFn Binder) {
+ server.bindFns[baseDN] = bindFn
+}
+func (server *Server) SearchFunc(baseDN string, searchFn Searcher) {
+ server.searchFns[baseDN] = searchFn
+}
+func (server *Server) CloseFunc(baseDN string, closeFn Closer) {
+ server.closeFns[baseDN] = closeFn
+}
+func (server *Server) QuitChannel(quit chan bool) {
+ server.quit = quit
+}
+
+func (server *Server) ListenAndServeTLS(listenString string, certFile string, keyFile string) error {
+ cert, err := tls.LoadX509KeyPair(certFile, keyFile)
+ if err != nil {
+ return err
+ }
+ tlsConfig := tls.Config{Certificates: []tls.Certificate{cert}}
+ tlsConfig.ServerName = "localhost"
+ ln, err := tls.Listen("tcp", listenString, &tlsConfig)
+ if err != nil {
+ return err
+ }
+ err = server.serve(ln)
+ if err != nil {
+ return err
+ }
+ return nil
+}
+
+func (server *Server) SetStats(enable bool) {
+ if enable {
+ server.stats = &Stats{}
+ } else {
+ server.stats = nil
+ }
+}
+
+func (server *Server) GetStats() Stats {
+ defer func() {
+ server.stats.statsMutex.Unlock()
+ }()
+ server.stats.statsMutex.Lock()
+ return *server.stats
+}
+
+func (server *Server) ListenAndServe(listenString string) error {
+ ln, err := net.Listen("tcp", listenString)
+ if err != nil {
+ return err
+ }
+ err = server.serve(ln)
+ if err != nil {
+ return err
+ }
+ return nil
+}
+
+func (server *Server) serve(ln net.Listener) error {
+ newConn := make(chan net.Conn)
+ go func() {
+ for {
+ conn, err := ln.Accept()
+ if err != nil {
+ if !strings.HasSuffix(err.Error(), "use of closed network connection") {
+ log.Printf("Error accepting network connection: %s", err.Error())
+ }
+ break
+ }
+ newConn <- conn
+ }
+ }()
+
+listener:
+ for {
+ select {
+ case c := <-newConn:
+ server.stats.countConns(1)
+ go server.handleConnection(c)
+ case <-server.quit:
+ ln.Close()
+ break listener
+ }
+ }
+ return nil
+}
+
+/////////////////////////
+
+func (server *Server) handleConnection(conn net.Conn) {
+ boundDN := "" // "" == anonymous
+
+handler:
+ for {
+ // read incoming LDAP packet
+ packet, err := ber.ReadPacket(conn)
+ if err == io.EOF { // Client closed connection
+ break
+ } else if err != nil {
+ log.Printf("handleConnection ber.ReadPacket ERROR: %s", err.Error())
+ break
+ }
+
+ // sanity check this packet
+ if len(packet.Children) < 2 {
+ log.Print("len(packet.Children) < 2")
+ break
+ }
+ // check the message ID and ClassType
+ messageID := packet.Children[0].Value.(uint64)
+ req := packet.Children[1]
+ if req.ClassType != ber.ClassApplication {
+ log.Print("req.ClassType != ber.ClassApplication")
+ break
+ }
+ // handle controls if present
+ if len(packet.Children) > 2 {
+ controls := packet.Children[2]
+ ber.PrintPacket(controls)
+ log.Print("TODO Parse Controls")
+ /*
+ Controls ::= SEQUENCE OF control Control
+
+ Control ::= SEQUENCE {
+ controlType LDAPOID,
+ criticality BOOLEAN DEFAULT FALSE, // unavailableCriticalExtension
+ controlValue OCTET STRING OPTIONAL }
+ */
+ }
+
+ // dispatch the LDAP operation
+ switch req.Tag { // ldap op code
+ default:
+ //log.Printf("Bound as %s", boundDN)
+ //ber.PrintPacket(packet)
+ log.Printf("Unhandled operation: %s [%d]", ApplicationMap[req.Tag], req.Tag)
+ break handler
+
+ case ApplicationBindRequest:
+ server.stats.countBinds(1)
+ ldapResultCode := server.handleBindRequest(req, server.bindFns, conn)
+ if ldapResultCode == LDAPResultSuccess {
+ boundDN = req.Children[1].Value.(string)
+ }
+ responsePacket := encodeBindResponse(messageID, ldapResultCode)
+ if err = sendPacket(conn, responsePacket); err != nil {
+ log.Printf("sendPacket error %s", err.Error())
+ break handler
+ }
+ case ApplicationSearchRequest:
+ server.stats.countSearches(1)
+ if err := server.handleSearchRequest(req, messageID, boundDN, server.searchFns, conn); err != nil {
+ log.Printf("handleSearchRequest error %s", err.Error()) // TODO: make this more testable/better err handling - stop using log, stop using breaks?
+ e := err.(*Error)
+ if err = sendPacket(conn, encodeSearchDone(messageID, uint64(e.ResultCode))); err != nil {
+ log.Printf("sendPacket error %s", err.Error())
+ }
+ break handler
+ } else {
+ if err = sendPacket(conn, encodeSearchDone(messageID, LDAPResultSuccess)); err != nil {
+ log.Printf("sendPacket error %s", err.Error())
+ break handler
+ }
+ }
+ case ApplicationUnbindRequest:
+ server.stats.countUnbinds(1)
+ break handler // simply disconnect - this IS implemented
+ case ApplicationExtendedRequest:
+ responsePacket := encodeLDAPResponse(messageID, ApplicationExtendedResponse, LDAPResultProtocolError, "Unsupported extended request")
+ if err = sendPacket(conn, responsePacket); err != nil {
+ log.Printf("sendPacket error %s", err.Error())
+ }
+ break handler
+ case ApplicationAbandonRequest:
+ log.Printf("Abandoning request!")
+ break handler
+
+ // Unimplemented LDAP operations:
+ case ApplicationModifyRequest:
+ log.Printf("Unhandled operation: %s [%d]", ApplicationMap[req.Tag], req.Tag)
+ break handler
+ case ApplicationAddRequest:
+ log.Printf("Unhandled operation: %s [%d]", ApplicationMap[req.Tag], req.Tag)
+ break handler
+ case ApplicationDelRequest:
+ log.Printf("Unhandled operation: %s [%d]", ApplicationMap[req.Tag], req.Tag)
+ break handler
+ case ApplicationModifyDNRequest:
+ log.Printf("Unhandled operation: %s [%d]", ApplicationMap[req.Tag], req.Tag)
+ break handler
+ case ApplicationCompareRequest:
+ log.Printf("Unhandled operation: %s [%d]", ApplicationMap[req.Tag], req.Tag)
+ break handler
+ }
+ }
+
+ for _, c := range server.closeFns {
+ c.Close(conn)
+ }
+
+ conn.Close()
+}
+
+/////////////////////////
+func (server *Server) handleSearchRequest(req *ber.Packet, messageID uint64, boundDN string, searchFns map[string]Searcher, conn net.Conn) (resultErr error) {
+ defer func() {
+ if r := recover(); r != nil {
+ resultErr = NewError(LDAPResultOperationsError, fmt.Errorf("Search function panic: %s", r))
+ }
+ }()
+
+ searchReq, err := parseSearchRequest(boundDN, req)
+ if err != nil {
+ return NewError(LDAPResultOperationsError, err)
+ }
+
+ filterPacket, err := CompileFilter(searchReq.Filter)
+ if err != nil {
+ return NewError(LDAPResultOperationsError, err)
+ }
+
+ fnNames := []string{}
+ for k := range searchFns {
+ fnNames = append(fnNames, k)
+ }
+ searchFn := routeFunc(searchReq.BaseDN, fnNames)
+ searchResp, err := searchFns[searchFn].Search(boundDN, searchReq, conn)
+ if err != nil {
+ return NewError(uint8(searchResp.ResultCode), err)
+ }
+
+ if server.EnforceLDAP {
+ if searchReq.DerefAliases != NeverDerefAliases { // [-a {never|always|search|find}
+ // TODO: Server DerefAliases not implemented: RFC4511 4.5.1.3. SearchRequest.derefAliases
+ }
+ if len(searchReq.Controls) > 0 {
+ return NewError(LDAPResultOperationsError, errors.New("Server controls not implemented")) // TODO
+ }
+ if searchReq.TimeLimit > 0 {
+ return NewError(LDAPResultOperationsError, errors.New("Server TimeLimit not implemented")) // TODO
+ }
+ }
+
+ for i, entry := range searchResp.Entries {
+ if server.EnforceLDAP {
+ // size limit
+ if searchReq.SizeLimit > 0 && i >= searchReq.SizeLimit {
+ break
+ }
+
+ // filter
+ keep, resultCode := ServerApplyFilter(filterPacket, entry)
+ if resultCode != LDAPResultSuccess {
+ return NewError(uint8(resultCode), errors.New("ServerApplyFilter error"))
+ }
+ if !keep {
+ continue
+ }
+
+ // constrained search scope
+ switch searchReq.Scope {
+ case ScopeWholeSubtree: // The scope is constrained to the entry named by baseObject and to all its subordinates.
+ case ScopeBaseObject: // The scope is constrained to the entry named by baseObject.
+ if entry.DN != searchReq.BaseDN {
+ continue
+ }
+ case ScopeSingleLevel: // The scope is constrained to the immediate subordinates of the entry named by baseObject.
+ parts := strings.Split(entry.DN, ",")
+ if len(parts) < 2 && entry.DN != searchReq.BaseDN {
+ continue
+ }
+ if dn := strings.Join(parts[1:], ","); dn != searchReq.BaseDN {
+ continue
+ }
+ }
+
+ // attributes
+ if len(searchReq.Attributes) > 1 || (len(searchReq.Attributes) == 1 && len(searchReq.Attributes[0]) > 0) {
+ entry, err = filterAttributes(entry, searchReq.Attributes)
+ if err != nil {
+ return NewError(LDAPResultOperationsError, err)
+ }
+ }
+ }
+
+ // respond
+ responsePacket := encodeSearchResponse(messageID, searchReq, entry)
+ if err = sendPacket(conn, responsePacket); err != nil {
+ return NewError(LDAPResultOperationsError, err)
+ }
+ }
+ return nil
+}
+
+/////////////////////////
+func (server *Server) handleBindRequest(req *ber.Packet, bindFns map[string]Binder, conn net.Conn) (resultCode uint64) {
+ defer func() {
+ if r := recover(); r != nil {
+ resultCode = LDAPResultOperationsError
+ }
+ }()
+
+ // we only support ldapv3
+ ldapVersion := req.Children[0].Value.(uint64)
+ if ldapVersion != 3 {
+ log.Printf("Unsupported LDAP version: %d", ldapVersion)
+ return LDAPResultInappropriateAuthentication
+ }
+
+ // auth types
+ bindDN := req.Children[1].Value.(string)
+ bindAuth := req.Children[2]
+ switch bindAuth.Tag {
+ default:
+ log.Print("Unknown LDAP authentication method")
+ return LDAPResultInappropriateAuthentication
+ case LDAPBindAuthSimple:
+ if len(req.Children) == 3 {
+ fnNames := []string{}
+ for k := range bindFns {
+ fnNames = append(fnNames, k)
+ }
+ bindFn := routeFunc(bindDN, fnNames)
+ resultCode, err := bindFns[bindFn].Bind(bindDN, bindAuth.Data.String(), conn)
+ if err != nil {
+ log.Printf("BindFn Error %s", err.Error())
+ }
+ return resultCode
+ } else {
+ log.Print("Simple bind request has wrong # children. len(req.Children) != 3")
+ return LDAPResultInappropriateAuthentication
+ }
+ case LDAPBindAuthSASL:
+ log.Print("SASL authentication is not supported")
+ return LDAPResultInappropriateAuthentication
+ }
+ return LDAPResultOperationsError
+}
+
+/////////////////////////
+func sendPacket(conn net.Conn, packet *ber.Packet) error {
+ _, err := conn.Write(packet.Bytes())
+ if err != nil {
+ log.Printf("Error Sending Message: %s", err.Error())
+ return err
+ }
+ return nil
+}
+
+/////////////////////////
+func parseSearchRequest(boundDN string, req *ber.Packet) (SearchRequest, error) {
+ if len(req.Children) != 8 {
+ return SearchRequest{}, NewError(LDAPResultOperationsError, errors.New("Bad search request"))
+ }
+
+ // Parse the request
+ baseObject := req.Children[0].Value.(string)
+ scope := int(req.Children[1].Value.(uint64))
+ derefAliases := int(req.Children[2].Value.(uint64))
+ sizeLimit := int(req.Children[3].Value.(uint64))
+ timeLimit := int(req.Children[4].Value.(uint64))
+ typesOnly := false
+ if req.Children[5].Value != nil {
+ typesOnly = req.Children[5].Value.(bool)
+ }
+ filter, err := DecompileFilter(req.Children[6])
+ if err != nil {
+ return SearchRequest{}, err
+ }
+ attributes := []string{}
+ for _, attr := range req.Children[7].Children {
+ attributes = append(attributes, attr.Value.(string))
+ }
+ searchReq := SearchRequest{baseObject, scope,
+ derefAliases, sizeLimit, timeLimit,
+ typesOnly, filter, attributes, nil}
+
+ return searchReq, nil
+}
+
+/////////////////////////
+func routeFunc(dn string, funcNames []string) string {
+ bestPick := ""
+ for _, fn := range funcNames {
+ if strings.HasSuffix(dn, fn) {
+ l := len(strings.Split(bestPick, ","))
+ if bestPick == "" {
+ l = 0
+ }
+ if len(strings.Split(fn, ",")) > l {
+ bestPick = fn
+ }
+ }
+ }
+ return bestPick
+}
+
+/////////////////////////
+func filterAttributes(entry *Entry, attributes []string) (*Entry, error) {
+ // only return requested attributes
+ newAttributes := []*EntryAttribute{}
+
+ for _, attr := range entry.Attributes {
+ for _, requested := range attributes {
+ if strings.ToLower(attr.Name) == strings.ToLower(requested) {
+ newAttributes = append(newAttributes, attr)
+ }
+ }
+ }
+ entry.Attributes = newAttributes
+
+ return entry, nil
+}
+
+/////////////////////////
+type defaultHandler struct {
+}
+
+func (h defaultHandler) Bind(bindDN, bindSimplePw string, conn net.Conn) (uint64, error) {
+ return LDAPResultInappropriateAuthentication, nil
+}
+func (h defaultHandler) Search(boundDN string, searchReq SearchRequest, conn net.Conn) (ServerSearchResult, error) {
+ return ServerSearchResult{make([]*Entry, 0), []string{}, []Control{}, LDAPResultSuccess}, nil
+}
+func (h defaultHandler) Close(conn net.Conn) error {
+ conn.Close()
+ return nil
+}
+
+/////////////////////////
+func encodeBindResponse(messageID uint64, ldapResultCode uint64) *ber.Packet {
+ responsePacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response")
+ responsePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "Message ID"))
+
+ bindReponse := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindResponse, nil, "Bind Response")
+ bindReponse.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, ldapResultCode, "resultCode: "))
+ bindReponse.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "matchedDN: "))
+ bindReponse.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "errorMessage: "))
+
+ responsePacket.AppendChild(bindReponse)
+
+ // ber.PrintPacket(responsePacket)
+ return responsePacket
+}
+func encodeSearchResponse(messageID uint64, req SearchRequest, res *Entry) *ber.Packet {
+ responsePacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response")
+ responsePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "Message ID"))
+
+ searchEntry := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationSearchResultEntry, nil, "Search Result Entry")
+ searchEntry.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, res.DN, "Object Name"))
+
+ attrs := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Attributes:")
+ for _, attribute := range res.Attributes {
+ attrs.AppendChild(encodeSearchAttribute(attribute.Name, attribute.Values))
+ }
+
+ searchEntry.AppendChild(attrs)
+ responsePacket.AppendChild(searchEntry)
+
+ return responsePacket
+}
+
+func encodeSearchAttribute(name string, values []string) *ber.Packet {
+ packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Attribute")
+ packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, name, "Attribute Name"))
+
+ valuesPacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSet, nil, "Attribute Values")
+ for _, value := range values {
+ valuesPacket.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, value, "Attribute Value"))
+ }
+
+ packet.AppendChild(valuesPacket)
+
+ return packet
+}
+
+func encodeSearchDone(messageID uint64, ldapResultCode uint64) *ber.Packet {
+ responsePacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response")
+ responsePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "Message ID"))
+ donePacket := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationSearchResultDone, nil, "Search result done")
+ donePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, ldapResultCode, "resultCode: "))
+ donePacket.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "matchedDN: "))
+ donePacket.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "errorMessage: "))
+ responsePacket.AppendChild(donePacket)
+
+ return responsePacket
+}
+
+func encodeLDAPResponse(messageID uint64, responseType uint8, ldapResultCode uint64, message string) *ber.Packet {
+ responsePacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response")
+ responsePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "Message ID"))
+ reponse := ber.Encode(ber.ClassApplication, ber.TypeConstructed, responseType, nil, ApplicationMap[responseType])
+ reponse.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, ldapResultCode, "resultCode: "))
+ reponse.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "matchedDN: "))
+ reponse.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, message, "errorMessage: "))
+ responsePacket.AppendChild(reponse)
+ return responsePacket
+}
+
+/////////////////////////
+func (stats *Stats) countConns(delta int) {
+ if stats != nil {
+ stats.statsMutex.Lock()
+ stats.Conns += delta
+ stats.statsMutex.Unlock()
+ }
+}
+func (stats *Stats) countBinds(delta int) {
+ if stats != nil {
+ stats.statsMutex.Lock()
+ stats.Binds += delta
+ stats.statsMutex.Unlock()
+ }
+}
+func (stats *Stats) countUnbinds(delta int) {
+ if stats != nil {
+ stats.statsMutex.Lock()
+ stats.Unbinds += delta
+ stats.statsMutex.Unlock()
+ }
+}
+func (stats *Stats) countSearches(delta int) {
+ if stats != nil {
+ stats.statsMutex.Lock()
+ stats.Searches += delta
+ stats.statsMutex.Unlock()
+ }
+}
+
+/////////////////////////