summaryrefslogblamecommitdiffstats
path: root/filter.go
blob: 66c8afeeb354735e020118fda8b3683dda4993dc (plain) (tree)
1
2
3
4
5
6
7
8
9







                                                      
                
             
                                     


       









                                 

 










                                                  


       


                                   

 



                                                      

 











                                                                                                                                               

 








                                                                                                    
 

























                                                                    
 






























                                                                                                        
 

                  

 











                                                                                                
 
                           

 








                                                                                                
 



















































































                                                                                                                                                                       
 
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// This file contains the LDAP search filter compiler (filter string to BER
// packet) and decompiler (BER packet back to filter string).

package ldap

import (
	"errors"
	"fmt"
	"github.com/tmfkams/asn1-ber"
)

// Filter choice tags. The values match the context-specific tags of the
// Filter CHOICE in RFC 4511, section 4.5.1, and are used as the tag numbers
// of compiled filter packets.
const (
	FilterAnd = iota
	FilterOr
	FilterNot
	FilterEqualityMatch
	FilterSubstrings
	FilterGreaterOrEqual
	FilterLessOrEqual
	FilterPresent
	FilterApproxMatch
	FilterExtensibleMatch
)

// FilterMap gives a human-readable name for each filter tag; the filter
// compiler uses these names as packet descriptions.
var FilterMap = map[uint64]string{
	FilterAnd:             "And",
	FilterOr:              "Or",
	FilterNot:             "Not",
	FilterEqualityMatch:   "Equality Match",
	FilterSubstrings:      "Substrings",
	FilterGreaterOrEqual:  "Greater Or Equal",
	FilterLessOrEqual:     "Less Or Equal",
	FilterPresent:         "Present",
	FilterApproxMatch:     "Approx Match",
	FilterExtensibleMatch: "Extensible Match",
}

// Substring component tags inside a substrings filter (the initial, any and
// final choices of SubstringFilter in RFC 4511, section 4.5.1).
const (
	FilterSubstringsInitial = iota
	FilterSubstringsAny
	FilterSubstringsFinal
)

// FilterSubstringsMap gives a human-readable name for each substring tag.
var FilterSubstringsMap = map[uint64]string{
	FilterSubstringsInitial: "Substrings Initial",
	FilterSubstringsAny:     "Substrings Any",
	FilterSubstringsFinal:   "Substrings Final",
}

// CompileFilter compiles a string search filter such as "(&(cn=foo)(sn=ba*))"
// into a BER packet. The string must begin with '(' and must consist of
// exactly one filter; a missing leading '(' or any characters left over after
// the filter are reported as ErrorFilterCompile errors.
func CompileFilter(filter string) (*ber.Packet, *Error) {
	if filter == "" || filter[0] != '(' {
		return nil, NewError(ErrorFilterCompile, errors.New("Filter does not start with an '('"))
	}

	packet, end, err := compileFilter(filter, 1)
	switch {
	case err != nil:
		return nil, err
	case end != len(filter):
		// Report the unconsumed tail so the caller can see what was left over.
		return nil, NewError(ErrorFilterCompile, fmt.Errorf("Finished compiling filter with extra at end.\n%s", filter[end:]))
	default:
		return packet, nil
	}
}

// DecompileFilter renders a compiled filter packet back into its string form,
// for example "(&(cn=foo)(sn=ba*))".
//
// Attribute names and assertion values are copied verbatim; no escaping is
// applied. If the packet's shape does not match what the compiler produces
// (for instance a missing child), the resulting panic is recovered and
// reported as an ErrorFilterDecompile error. Filter types this function
// cannot render, such as extensible match, are also reported as
// ErrorFilterDecompile errors. On error the returned string is empty.
func DecompileFilter(packet *ber.Packet) (ret string, err *Error) {
	defer func() {
		if r := recover(); r != nil {
			ret = ""
			err = NewError(ErrorFilterDecompile, errors.New("Error decompiling filter"))
		}
	}()

	// item renders a two-child "<attribute><op><value>" comparison.
	item := func(op string) string {
		return ber.DecodeString(packet.Children[0].Data.Bytes()) + op +
			ber.DecodeString(packet.Children[1].Data.Bytes())
	}

	ret = "("
	switch packet.Tag {
	case FilterAnd, FilterOr:
		if packet.Tag == FilterAnd {
			ret += "&"
		} else {
			ret += "|"
		}
		for _, child := range packet.Children {
			childStr, childErr := DecompileFilter(child)
			if childErr != nil {
				return "", childErr
			}
			ret += childStr
		}
	case FilterNot:
		childStr, childErr := DecompileFilter(packet.Children[0])
		if childErr != nil {
			return "", childErr
		}
		ret += "!" + childStr
	case FilterSubstrings:
		ret += ber.DecodeString(packet.Children[0].Data.Bytes()) + "="
		subs := packet.Children[1].Children
		if len(subs) == 0 {
			return "", NewError(ErrorFilterDecompile, errors.New("Error decompiling filter"))
		}
		// Render every component, not just the first. A '*' precedes the first
		// component unless it is an initial one, and a '*' follows every
		// component except a final one, so initial "a", any "b", final "c"
		// becomes "a*b*c".
		for i, sub := range subs {
			if i == 0 && sub.Tag != FilterSubstringsInitial {
				ret += "*"
			}
			ret += ber.DecodeString(sub.Data.Bytes())
			if sub.Tag != FilterSubstringsFinal {
				ret += "*"
			}
		}
	case FilterEqualityMatch:
		ret += item("=")
	case FilterGreaterOrEqual:
		ret += item(">=")
	case FilterLessOrEqual:
		ret += item("<=")
	case FilterPresent:
		ret += ber.DecodeString(packet.Children[0].Data.Bytes()) + "=*"
	case FilterApproxMatch:
		ret += item("~=")
	default:
		// Emitting "()" here would yield an invalid filter string.
		return "", NewError(ErrorFilterDecompile, fmt.Errorf("Unsupported filter type %d", packet.Tag))
	}

	return ret + ")", nil
}

// compileFilterSet compiles the parenthesized child filters of an '&' or '|'
// filter, starting at filter[pos], and appends each one to parent. On success
// it returns the index just past the set's closing ')'.
func compileFilterSet(filter string, pos int, parent *ber.Packet) (int, *Error) {
	for pos < len(filter) && filter[pos] == '(' {
		child, next, err := compileFilter(filter, pos+1)
		if err != nil {
			return pos, err
		}
		pos = next
		parent.AppendChild(child)
	}
	if pos == len(filter) {
		return pos, NewError(ErrorFilterCompile, errors.New("Unexpected end of filter"))
	}
	// The only valid way to end the child list is the set's own ')'. Any
	// other character (e.g. the "x" in "(&(a=b)x)") must be rejected rather
	// than consumed as a closing parenthesis, or unbalanced input such as
	// "(|(&(a=b)x(c=d))" would compile into a different filter.
	if filter[pos] != ')' {
		return pos, NewError(ErrorFilterCompile, fmt.Errorf("Unexpected character %q in filter set", filter[pos]))
	}
	return pos + 1, nil
}

// compileFilter compiles the single filter whose body starts at filter[pos],
// where pos is the index just past the filter's opening '('. On success it
// returns the compiled packet and the index just past the filter's closing
// ')'.
//
// A panic raised while building packets is recovered and reported as an
// ErrorFilterCompile error.
func compileFilter(filter string, pos int) (p *ber.Packet, newPos int, err *Error) {
	defer func() {
		if r := recover(); r != nil {
			err = NewError(ErrorFilterCompile, errors.New("Error compiling filter"))
		}
	}()
	newPos = pos
	if pos >= len(filter) {
		err = NewError(ErrorFilterCompile, errors.New("Unexpected end of filter"))
		return
	}

	switch filter[pos] {
	case '(':
		// The body is itself a parenthesized filter, as in "(!(a=b))" or
		// "((a=b))": compile it, then consume this filter's own ')'.
		p, newPos, err = compileFilter(filter, pos+1)
		if err != nil {
			return
		}
		if newPos >= len(filter) {
			p = nil
			err = NewError(ErrorFilterCompile, errors.New("Reached end of filter without closing parens"))
			return
		}
		if filter[newPos] != ')' {
			p = nil
			err = NewError(ErrorFilterCompile, fmt.Errorf("Unexpected character %q after nested filter", filter[newPos]))
			return
		}
		newPos++
		return
	case '&':
		p = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterAnd, nil, FilterMap[FilterAnd])
		newPos, err = compileFilterSet(filter, pos+1, p)
		return
	case '|':
		p = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterOr, nil, FilterMap[FilterOr])
		newPos, err = compileFilterSet(filter, pos+1, p)
		return
	case '!':
		p = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterNot, nil, FilterMap[FilterNot])
		var child *ber.Packet
		child, newPos, err = compileFilter(filter, pos+1)
		if err != nil {
			// Do not attach a nil child to the Not packet.
			return
		}
		p.AppendChild(child)
		return
	}
	return compileItem(filter, pos)
}

// compileItem compiles a simple item body such as "cn=foo", "age>=21",
// "age<=65", "cn~=foo", "cn=*" (present) or "cn=a*b*c" (substrings). pos is
// the index of the first attribute character; on success the returned index
// is just past the closing ')'.
func compileItem(filter string, pos int) (*ber.Packet, int, *Error) {
	var p *ber.Packet
	opStart, valueStart := -1, -1
	wildcard := false // the assertion value contains at least one '*'

	i := pos
	for ; i < len(filter) && filter[i] != ')'; i++ {
		if p != nil {
			// Everything after the operator belongs to the assertion value.
			if filter[i] == '*' {
				wildcard = true
			}
			continue
		}
		var next byte
		if i+1 < len(filter) {
			next = filter[i+1]
		}
		switch {
		case filter[i] == '=':
			p = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterEqualityMatch, nil, FilterMap[FilterEqualityMatch])
			opStart, valueStart = i, i+1
		case filter[i] == '>' && next == '=':
			p = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterGreaterOrEqual, nil, FilterMap[FilterGreaterOrEqual])
			opStart, valueStart = i, i+2
			i++
		case filter[i] == '<' && next == '=':
			p = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterLessOrEqual, nil, FilterMap[FilterLessOrEqual])
			opStart, valueStart = i, i+2
			i++
		case filter[i] == '~' && next == '=':
			p = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterApproxMatch, nil, FilterMap[FilterApproxMatch])
			opStart, valueStart = i, i+2
			i++
		}
	}
	if i == len(filter) {
		return nil, i, NewError(ErrorFilterCompile, errors.New("Unexpected end of filter"))
	}
	if p == nil {
		return nil, i, NewError(ErrorFilterCompile, errors.New("Error parsing filter"))
	}

	// Slice the input instead of rebuilding it byte by byte, so multi-byte
	// UTF-8 characters are passed through unchanged.
	attribute := filter[pos:opStart]
	value := filter[valueStart:i]
	p.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimative, ber.TagOctetString, attribute, "Attribute"))

	switch {
	case p.Tag == FilterEqualityMatch && value == "*":
		p.Tag = FilterPresent
		p.Description = FilterMap[uint64(p.Tag)]
	case p.Tag == FilterEqualityMatch && wildcard:
		seq := compileSubstrings(value)
		if len(seq.Children) == 0 {
			// A value made only of '*' characters (other than a lone "*")
			// has no substring components to match.
			return nil, i, NewError(ErrorFilterCompile, errors.New("Error parsing filter"))
		}
		p.Tag = FilterSubstrings
		p.Description = FilterMap[uint64(p.Tag)]
		p.AppendChild(seq)
	default:
		p.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimative, ber.TagOctetString, value, "Condition"))
	}
	return p, i + 1, nil
}

// compileSubstrings builds the substrings sequence for an assertion value that
// contains '*' wildcards. For example, "a*b*c" yields initial "a", any "b" and
// final "c". Empty components (from leading, trailing or adjacent '*') are
// omitted, so the returned sequence may have no children.
func compileSubstrings(value string) *ber.Packet {
	seq := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Substrings")
	start := 0
	for j := 0; j <= len(value); j++ {
		if j < len(value) && value[j] != '*' {
			continue
		}
		// value[start:j] is the component ending at this '*' (or at the end).
		if part := value[start:j]; part != "" {
			switch {
			case start == 0:
				seq.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimative, FilterSubstringsInitial, part, "Initial Substring"))
			case j == len(value):
				seq.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimative, FilterSubstringsFinal, part, "Final Substring"))
			default:
				seq.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimative, FilterSubstringsAny, part, "Any Substring"))
			}
		}
		start = j + 1
	}
	return seq
}