summaryrefslogtreecommitdiffstats
path: root/filter.go
diff options
context:
space:
mode:
Diffstat (limited to 'filter.go')
-rw-r--r--filter.go161
1 files changed, 147 insertions, 14 deletions
diff --git a/filter.go b/filter.go
index 690f67d..d7bc798 100644
--- a/filter.go
+++ b/filter.go
@@ -7,8 +7,8 @@ package ldap
import (
"errors"
"fmt"
-
- "github.com/vanackere/asn1-ber"
+ "github.com/nmcclain/asn1-ber"
+ "strings"
)
const (
@@ -24,7 +24,7 @@ const (
FilterExtensibleMatch = 9
)
-var filterMap = map[uint8]string{
+var FilterMap = map[uint8]string{
FilterAnd: "And",
FilterOr: "Or",
FilterNot: "Not",
@@ -163,15 +163,15 @@ func compileFilter(filter string, pos int) (*ber.Packet, int, error) {
newPos++
return packet, newPos, err
case '&':
- packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterAnd, nil, filterMap[FilterAnd])
+ packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterAnd, nil, FilterMap[FilterAnd])
newPos, err = compileFilterSet(filter, pos+1, packet)
return packet, newPos, err
case '|':
- packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterOr, nil, filterMap[FilterOr])
+ packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterOr, nil, FilterMap[FilterOr])
newPos, err = compileFilterSet(filter, pos+1, packet)
return packet, newPos, err
case '!':
- packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterNot, nil, filterMap[FilterNot])
+ packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterNot, nil, FilterMap[FilterNot])
var child *ber.Packet
child, newPos, err = compileFilter(filter, pos+1)
packet.AppendChild(child)
@@ -184,15 +184,15 @@ func compileFilter(filter string, pos int) (*ber.Packet, int, error) {
case packet != nil:
condition += fmt.Sprintf("%c", filter[newPos])
case filter[newPos] == '=':
- packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterEqualityMatch, nil, filterMap[FilterEqualityMatch])
+ packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterEqualityMatch, nil, FilterMap[FilterEqualityMatch])
case filter[newPos] == '>' && filter[newPos+1] == '=':
- packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterGreaterOrEqual, nil, filterMap[FilterGreaterOrEqual])
+ packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterGreaterOrEqual, nil, FilterMap[FilterGreaterOrEqual])
newPos++
case filter[newPos] == '<' && filter[newPos+1] == '=':
- packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterLessOrEqual, nil, filterMap[FilterLessOrEqual])
+ packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterLessOrEqual, nil, FilterMap[FilterLessOrEqual])
newPos++
case filter[newPos] == '~' && filter[newPos+1] == '=':
- packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterApproxMatch, nil, filterMap[FilterLessOrEqual])
+ packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterApproxMatch, nil, FilterMap[FilterLessOrEqual])
newPos++
case packet == nil:
attribute += fmt.Sprintf("%c", filter[newPos])
@@ -211,7 +211,7 @@ func compileFilter(filter string, pos int) (*ber.Packet, int, error) {
if packet.Tag == FilterEqualityMatch && condition == "*" {
packet.TagType = ber.TypePrimitive
packet.Tag = FilterPresent
- packet.Description = filterMap[packet.Tag]
+ packet.Description = FilterMap[packet.Tag]
packet.Data.WriteString(attribute)
return packet, newPos + 1, nil
}
@@ -220,21 +220,21 @@ func compileFilter(filter string, pos int) (*ber.Packet, int, error) {
case packet.Tag == FilterEqualityMatch && condition[0] == '*' && condition[len(condition)-1] == '*':
// Any
packet.Tag = FilterSubstrings
- packet.Description = filterMap[packet.Tag]
+ packet.Description = FilterMap[packet.Tag]
seq := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Substrings")
seq.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, FilterSubstringsAny, condition[1:len(condition)-1], "Any Substring"))
packet.AppendChild(seq)
case packet.Tag == FilterEqualityMatch && condition[0] == '*':
// Final
packet.Tag = FilterSubstrings
- packet.Description = filterMap[packet.Tag]
+ packet.Description = FilterMap[packet.Tag]
seq := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Substrings")
seq.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, FilterSubstringsFinal, condition[1:], "Final Substring"))
packet.AppendChild(seq)
case packet.Tag == FilterEqualityMatch && condition[len(condition)-1] == '*':
// Initial
packet.Tag = FilterSubstrings
- packet.Description = filterMap[packet.Tag]
+ packet.Description = FilterMap[packet.Tag]
seq := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Substrings")
seq.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, FilterSubstringsInitial, condition[:len(condition)-1], "Initial Substring"))
packet.AppendChild(seq)
@@ -245,3 +245,136 @@ func compileFilter(filter string, pos int) (*ber.Packet, int, error) {
return packet, newPos, err
}
}
+
+func ServerApplyFilter(f *ber.Packet, entry *Entry) (bool, uint64) {
+ //log.Printf("%# v", pretty.Formatter(entry))
+
+ switch FilterMap[f.Tag] {
+ default:
+ //log.Fatalf("Unknown LDAP filter code: %d", f.Tag)
+ return false, LDAPResultOperationsError
+ case "Equality Match":
+ if len(f.Children) != 2 {
+ return false, LDAPResultOperationsError
+ }
+ attribute := f.Children[0].Value.(string)
+ value := f.Children[1].Value.(string)
+ for _, a := range entry.Attributes {
+ if strings.ToLower(a.Name) == strings.ToLower(attribute) {
+ for _, v := range a.Values {
+ if strings.ToLower(v) == strings.ToLower(value) {
+ return true, LDAPResultSuccess
+ }
+ }
+ }
+ }
+ case "Present":
+ for _, a := range entry.Attributes {
+ if strings.ToLower(a.Name) == strings.ToLower(f.Data.String()) {
+ return true, LDAPResultSuccess
+ }
+ }
+ case "And":
+ for _, child := range f.Children {
+ ok, exitCode := ServerApplyFilter(child, entry)
+ if exitCode != LDAPResultSuccess {
+ return false, exitCode
+ }
+ if !ok {
+ return false, LDAPResultSuccess
+ }
+ }
+ return true, LDAPResultSuccess
+ case "Or":
+ anyOk := false
+ for _, child := range f.Children {
+ ok, exitCode := ServerApplyFilter(child, entry)
+ if exitCode != LDAPResultSuccess {
+ return false, exitCode
+ } else if ok {
+ anyOk = true
+ }
+ }
+ if anyOk {
+ return true, LDAPResultSuccess
+ }
+ case "Not":
+ if len(f.Children) != 1 {
+ return false, LDAPResultOperationsError
+ }
+ ok, exitCode := ServerApplyFilter(f.Children[0], entry)
+ if exitCode != LDAPResultSuccess {
+ return false, exitCode
+ } else if !ok {
+ return true, LDAPResultSuccess
+ }
+ case "FilterSubstrings":
+ return false, LDAPResultOperationsError
+ case "FilterGreaterOrEqual":
+ return false, LDAPResultOperationsError
+ case "FilterLessOrEqual":
+ return false, LDAPResultOperationsError
+ case "FilterApproxMatch":
+ return false, LDAPResultOperationsError
+ case "FilterExtensibleMatch":
+ return false, LDAPResultOperationsError
+ }
+
+ return false, LDAPResultSuccess
+}
+
+func GetFilterType(filter string) (string, error) { // TODO <- test this
+ f, err := CompileFilter(filter)
+ if err != nil {
+ return "", err
+ }
+ return parseFilterType(f)
+}
+func parseFilterType(f *ber.Packet) (string, error) {
+ searchType := ""
+ switch FilterMap[f.Tag] {
+ case "Equality Match":
+ if len(f.Children) != 2 {
+ return "", errors.New("Equality match must have only two children")
+ }
+ attribute := strings.ToLower(f.Children[0].Value.(string))
+ value := f.Children[1].Value.(string)
+
+ if attribute == "objectclass" {
+ searchType = strings.ToLower(value)
+ }
+ case "And":
+ for _, child := range f.Children {
+ subType, err := parseFilterType(child)
+ if err != nil {
+ return "", err
+ }
+ if len(subType) > 0 {
+ searchType = subType
+ }
+ }
+ case "Or":
+ for _, child := range f.Children {
+ subType, err := parseFilterType(child)
+ if err != nil {
+ return "", err
+ }
+ if len(subType) > 0 {
+ searchType = subType
+ }
+ }
+ case "Not":
+ if len(f.Children) != 1 {
+ return "", errors.New("Not filter must have only one child")
+ }
+ subType, err := parseFilterType(f.Children[0])
+ if err != nil {
+ return "", err
+ }
+ if len(subType) > 0 {
+ searchType = subType
+ }
+
+ }
+ return strings.ToLower(searchType), nil
+}