package ldap import ( "crypto/tls" "errors" "fmt" "github.com/nmcclain/asn1-ber" "io" "log" "net" "strings" "sync" ) type Binder interface { Bind(bindDN, bindSimplePw string, conn net.Conn) (uint64, error) } type Searcher interface { Search(boundDN string, searchReq SearchRequest, conn net.Conn) (ServerSearchResult, error) } type Closer interface { Close(conn net.Conn) error } ///////////////////////// type Server struct { bindFns map[string]Binder searchFns map[string]Searcher closeFns map[string]Closer quit chan bool EnforceLDAP bool stats *Stats } type Stats struct { Conns int Binds int Unbinds int Searches int statsMutex sync.Mutex } type ServerSearchResult struct { Entries []*Entry Referrals []string Controls []Control ResultCode uint64 } ///////////////////////// func NewServer() *Server { s := new(Server) s.quit = make(chan bool) d := defaultHandler{} s.bindFns = make(map[string]Binder) s.searchFns = make(map[string]Searcher) s.closeFns = make(map[string]Closer) s.bindFns[""] = d s.searchFns[""] = d s.closeFns[""] = d s.stats = nil return s } func (server *Server) BindFunc(baseDN string, bindFn Binder) { server.bindFns[baseDN] = bindFn } func (server *Server) SearchFunc(baseDN string, searchFn Searcher) { server.searchFns[baseDN] = searchFn } func (server *Server) CloseFunc(baseDN string, closeFn Closer) { server.closeFns[baseDN] = closeFn } func (server *Server) QuitChannel(quit chan bool) { server.quit = quit } func (server *Server) ListenAndServeTLS(listenString string, certFile string, keyFile string) error { cert, err := tls.LoadX509KeyPair(certFile, keyFile) if err != nil { return err } tlsConfig := tls.Config{Certificates: []tls.Certificate{cert}} tlsConfig.ServerName = "localhost" ln, err := tls.Listen("tcp", listenString, &tlsConfig) if err != nil { return err } err = server.serve(ln) if err != nil { return err } return nil } func (server *Server) SetStats(enable bool) { if enable { server.stats = &Stats{} } else { server.stats = nil } } func (server *Server) GetStats() Stats { defer func() { server.stats.statsMutex.Unlock() }() server.stats.statsMutex.Lock() return *server.stats } func (server *Server) ListenAndServe(listenString string) error { ln, err := net.Listen("tcp", listenString) if err != nil { return err } err = server.serve(ln) if err != nil { return err } return nil } func (server *Server) serve(ln net.Listener) error { newConn := make(chan net.Conn) go func() { for { conn, err := ln.Accept() if err != nil { if !strings.HasSuffix(err.Error(), "use of closed network connection") { log.Printf("Error accepting network connection: %s", err.Error()) } break } newConn <- conn } }() listener: for { select { case c := <-newConn: server.stats.countConns(1) go server.handleConnection(c) case <-server.quit: ln.Close() break listener } } return nil } ///////////////////////// func (server *Server) handleConnection(conn net.Conn) { boundDN := "" // "" == anonymous handler: for { // read incoming LDAP packet packet, err := ber.ReadPacket(conn) if err == io.EOF { // Client closed connection break } else if err != nil { log.Printf("handleConnection ber.ReadPacket ERROR: %s", err.Error()) break } // sanity check this packet if len(packet.Children) < 2 { log.Print("len(packet.Children) < 2") break } // check the message ID and ClassType messageID := packet.Children[0].Value.(uint64) req := packet.Children[1] if req.ClassType != ber.ClassApplication { log.Print("req.ClassType != ber.ClassApplication") break } // handle controls if present if len(packet.Children) > 2 { controls := packet.Children[2] ber.PrintPacket(controls) log.Print("TODO Parse Controls") /* Controls ::= SEQUENCE OF control Control Control ::= SEQUENCE { controlType LDAPOID, criticality BOOLEAN DEFAULT FALSE, // unavailableCriticalExtension controlValue OCTET STRING OPTIONAL } */ } // dispatch the LDAP operation switch req.Tag { // ldap op code default: //log.Printf("Bound as %s", boundDN) //ber.PrintPacket(packet) log.Printf("Unhandled operation: %s [%d]", ApplicationMap[req.Tag], req.Tag) break handler case ApplicationBindRequest: server.stats.countBinds(1) ldapResultCode := server.handleBindRequest(req, server.bindFns, conn) if ldapResultCode == LDAPResultSuccess { boundDN = req.Children[1].Value.(string) } responsePacket := encodeBindResponse(messageID, ldapResultCode) if err = sendPacket(conn, responsePacket); err != nil { log.Printf("sendPacket error %s", err.Error()) break handler } case ApplicationSearchRequest: server.stats.countSearches(1) if err := server.handleSearchRequest(req, messageID, boundDN, server.searchFns, conn); err != nil { log.Printf("handleSearchRequest error %s", err.Error()) // TODO: make this more testable/better err handling - stop using log, stop using breaks? e := err.(*Error) if err = sendPacket(conn, encodeSearchDone(messageID, uint64(e.ResultCode))); err != nil { log.Printf("sendPacket error %s", err.Error()) } break handler } else { if err = sendPacket(conn, encodeSearchDone(messageID, LDAPResultSuccess)); err != nil { log.Printf("sendPacket error %s", err.Error()) break handler } } case ApplicationUnbindRequest: server.stats.countUnbinds(1) break handler // simply disconnect - this IS implemented case ApplicationExtendedRequest: responsePacket := encodeLDAPResponse(messageID, ApplicationExtendedResponse, LDAPResultProtocolError, "Unsupported extended request") if err = sendPacket(conn, responsePacket); err != nil { log.Printf("sendPacket error %s", err.Error()) } break handler case ApplicationAbandonRequest: log.Printf("Abandoning request!") break handler // Unimplemented LDAP operations: case ApplicationModifyRequest: log.Printf("Unhandled operation: %s [%d]", ApplicationMap[req.Tag], req.Tag) break handler case ApplicationAddRequest: log.Printf("Unhandled operation: %s [%d]", ApplicationMap[req.Tag], req.Tag) break handler case ApplicationDelRequest: log.Printf("Unhandled operation: %s [%d]", ApplicationMap[req.Tag], req.Tag) break handler case ApplicationModifyDNRequest: log.Printf("Unhandled operation: %s [%d]", ApplicationMap[req.Tag], req.Tag) break handler case ApplicationCompareRequest: log.Printf("Unhandled operation: %s [%d]", ApplicationMap[req.Tag], req.Tag) break handler } } for _, c := range server.closeFns { c.Close(conn) } conn.Close() } ///////////////////////// func (server *Server) handleSearchRequest(req *ber.Packet, messageID uint64, boundDN string, searchFns map[string]Searcher, conn net.Conn) (resultErr error) { defer func() { if r := recover(); r != nil { resultErr = NewError(LDAPResultOperationsError, fmt.Errorf("Search function panic: %s", r)) } }() searchReq, err := parseSearchRequest(boundDN, req) if err != nil { return NewError(LDAPResultOperationsError, err) } filterPacket, err := CompileFilter(searchReq.Filter) if err != nil { return NewError(LDAPResultOperationsError, err) } fnNames := []string{} for k := range searchFns { fnNames = append(fnNames, k) } searchFn := routeFunc(searchReq.BaseDN, fnNames) searchResp, err := searchFns[searchFn].Search(boundDN, searchReq, conn) if err != nil { return NewError(uint8(searchResp.ResultCode), err) } if server.EnforceLDAP { if searchReq.DerefAliases != NeverDerefAliases { // [-a {never|always|search|find} // TODO: Server DerefAliases not implemented: RFC4511 4.5.1.3. SearchRequest.derefAliases } if len(searchReq.Controls) > 0 { return NewError(LDAPResultOperationsError, errors.New("Server controls not implemented")) // TODO } if searchReq.TimeLimit > 0 { return NewError(LDAPResultOperationsError, errors.New("Server TimeLimit not implemented")) // TODO } } for i, entry := range searchResp.Entries { if server.EnforceLDAP { // size limit if searchReq.SizeLimit > 0 && i >= searchReq.SizeLimit { break } // filter keep, resultCode := ServerApplyFilter(filterPacket, entry) if resultCode != LDAPResultSuccess { return NewError(uint8(resultCode), errors.New("ServerApplyFilter error")) } if !keep { continue } // constrained search scope switch searchReq.Scope { case ScopeWholeSubtree: // The scope is constrained to the entry named by baseObject and to all its subordinates. case ScopeBaseObject: // The scope is constrained to the entry named by baseObject. if entry.DN != searchReq.BaseDN { continue } case ScopeSingleLevel: // The scope is constrained to the immediate subordinates of the entry named by baseObject. parts := strings.Split(entry.DN, ",") if len(parts) < 2 && entry.DN != searchReq.BaseDN { continue } if dn := strings.Join(parts[1:], ","); dn != searchReq.BaseDN { continue } } // attributes if len(searchReq.Attributes) > 1 || (len(searchReq.Attributes) == 1 && len(searchReq.Attributes[0]) > 0) { entry, err = filterAttributes(entry, searchReq.Attributes) if err != nil { return NewError(LDAPResultOperationsError, err) } } } // respond responsePacket := encodeSearchResponse(messageID, searchReq, entry) if err = sendPacket(conn, responsePacket); err != nil { return NewError(LDAPResultOperationsError, err) } } return nil } ///////////////////////// func (server *Server) handleBindRequest(req *ber.Packet, bindFns map[string]Binder, conn net.Conn) (resultCode uint64) { defer func() { if r := recover(); r != nil { resultCode = LDAPResultOperationsError } }() // we only support ldapv3 ldapVersion := req.Children[0].Value.(uint64) if ldapVersion != 3 { log.Printf("Unsupported LDAP version: %d", ldapVersion) return LDAPResultInappropriateAuthentication } // auth types bindDN := req.Children[1].Value.(string) bindAuth := req.Children[2] switch bindAuth.Tag { default: log.Print("Unknown LDAP authentication method") return LDAPResultInappropriateAuthentication case LDAPBindAuthSimple: if len(req.Children) == 3 { fnNames := []string{} for k := range bindFns { fnNames = append(fnNames, k) } bindFn := routeFunc(bindDN, fnNames) resultCode, err := bindFns[bindFn].Bind(bindDN, bindAuth.Data.String(), conn) if err != nil { log.Printf("BindFn Error %s", err.Error()) } return resultCode } else { log.Print("Simple bind request has wrong # children. len(req.Children) != 3") return LDAPResultInappropriateAuthentication } case LDAPBindAuthSASL: log.Print("SASL authentication is not supported") return LDAPResultInappropriateAuthentication } return LDAPResultOperationsError } ///////////////////////// func sendPacket(conn net.Conn, packet *ber.Packet) error { _, err := conn.Write(packet.Bytes()) if err != nil { log.Printf("Error Sending Message: %s", err.Error()) return err } return nil } ///////////////////////// func parseSearchRequest(boundDN string, req *ber.Packet) (SearchRequest, error) { if len(req.Children) != 8 { return SearchRequest{}, NewError(LDAPResultOperationsError, errors.New("Bad search request")) } // Parse the request baseObject := req.Children[0].Value.(string) scope := int(req.Children[1].Value.(uint64)) derefAliases := int(req.Children[2].Value.(uint64)) sizeLimit := int(req.Children[3].Value.(uint64)) timeLimit := int(req.Children[4].Value.(uint64)) typesOnly := false if req.Children[5].Value != nil { typesOnly = req.Children[5].Value.(bool) } filter, err := DecompileFilter(req.Children[6]) if err != nil { return SearchRequest{}, err } attributes := []string{} for _, attr := range req.Children[7].Children { attributes = append(attributes, attr.Value.(string)) } searchReq := SearchRequest{baseObject, scope, derefAliases, sizeLimit, timeLimit, typesOnly, filter, attributes, nil} return searchReq, nil } ///////////////////////// func routeFunc(dn string, funcNames []string) string { bestPick := "" for _, fn := range funcNames { if strings.HasSuffix(dn, fn) { l := len(strings.Split(bestPick, ",")) if bestPick == "" { l = 0 } if len(strings.Split(fn, ",")) > l { bestPick = fn } } } return bestPick } ///////////////////////// func filterAttributes(entry *Entry, attributes []string) (*Entry, error) { // only return requested attributes newAttributes := []*EntryAttribute{} for _, attr := range entry.Attributes { for _, requested := range attributes { if strings.ToLower(attr.Name) == strings.ToLower(requested) { newAttributes = append(newAttributes, attr) } } } entry.Attributes = newAttributes return entry, nil } ///////////////////////// type defaultHandler struct { } func (h defaultHandler) Bind(bindDN, bindSimplePw string, conn net.Conn) (uint64, error) { return LDAPResultInappropriateAuthentication, nil } func (h defaultHandler) Search(boundDN string, searchReq SearchRequest, conn net.Conn) (ServerSearchResult, error) { return ServerSearchResult{make([]*Entry, 0), []string{}, []Control{}, LDAPResultSuccess}, nil } func (h defaultHandler) Close(conn net.Conn) error { conn.Close() return nil } ///////////////////////// func encodeBindResponse(messageID uint64, ldapResultCode uint64) *ber.Packet { responsePacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response") responsePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "Message ID")) bindReponse := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindResponse, nil, "Bind Response") bindReponse.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, ldapResultCode, "resultCode: ")) bindReponse.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "matchedDN: ")) bindReponse.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "errorMessage: ")) responsePacket.AppendChild(bindReponse) // ber.PrintPacket(responsePacket) return responsePacket } func encodeSearchResponse(messageID uint64, req SearchRequest, res *Entry) *ber.Packet { responsePacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response") responsePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "Message ID")) searchEntry := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationSearchResultEntry, nil, "Search Result Entry") searchEntry.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, res.DN, "Object Name")) attrs := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Attributes:") for _, attribute := range res.Attributes { attrs.AppendChild(encodeSearchAttribute(attribute.Name, attribute.Values)) } searchEntry.AppendChild(attrs) responsePacket.AppendChild(searchEntry) return responsePacket } func encodeSearchAttribute(name string, values []string) *ber.Packet { packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Attribute") packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, name, "Attribute Name")) valuesPacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSet, nil, "Attribute Values") for _, value := range values { valuesPacket.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, value, "Attribute Value")) } packet.AppendChild(valuesPacket) return packet } func encodeSearchDone(messageID uint64, ldapResultCode uint64) *ber.Packet { responsePacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response") responsePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "Message ID")) donePacket := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationSearchResultDone, nil, "Search result done") donePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, ldapResultCode, "resultCode: ")) donePacket.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "matchedDN: ")) donePacket.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "errorMessage: ")) responsePacket.AppendChild(donePacket) return responsePacket } func encodeLDAPResponse(messageID uint64, responseType uint8, ldapResultCode uint64, message string) *ber.Packet { responsePacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response") responsePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "Message ID")) reponse := ber.Encode(ber.ClassApplication, ber.TypeConstructed, responseType, nil, ApplicationMap[responseType]) reponse.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, ldapResultCode, "resultCode: ")) reponse.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "matchedDN: ")) reponse.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, message, "errorMessage: ")) responsePacket.AppendChild(reponse) return responsePacket } ///////////////////////// func (stats *Stats) countConns(delta int) { if stats != nil { stats.statsMutex.Lock() stats.Conns += delta stats.statsMutex.Unlock() } } func (stats *Stats) countBinds(delta int) { if stats != nil { stats.statsMutex.Lock() stats.Binds += delta stats.statsMutex.Unlock() } } func (stats *Stats) countUnbinds(delta int) { if stats != nil { stats.statsMutex.Lock() stats.Unbinds += delta stats.statsMutex.Unlock() } } func (stats *Stats) countSearches(delta int) { if stats != nil { stats.statsMutex.Lock() stats.Searches += delta stats.statsMutex.Unlock() } } /////////////////////////