// summary refs log blame commit diff stats
// path: root/filter.go
// blob: 1b8fd1e482f8bdf20a6cfe5eabbdd1075dd2aff8 (plain) (tree)
























































































































































































































































                                                                                                                                                                        
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// File contains a filter compiler/decompiler
package ldap

import (
	"fmt"
   "os"
   "github.com/mmitton/asn1-ber"
)

// Filter type tags. compileFilter passes these to ber.Encode as
// context-class tags, and DecompileFilter switches on them to pick the
// textual operator. The numbering appears to follow the Filter CHOICE of the
// LDAP protocol (RFC 4511) -- confirm before renumbering.
const (
	// FilterAnd tags a conjunction, written "(&...)".
	FilterAnd = 0
	// FilterOr tags a disjunction, written "(|...)".
	FilterOr = 1
	// FilterNot tags a negation, written "(!...)".
	FilterNot = 2
	// FilterEqualityMatch tags "attr=value".
	FilterEqualityMatch = 3
	// FilterSubstrings tags "attr=value" where value carries a leading
	// and/or trailing '*'.
	FilterSubstrings = 4
	// FilterGreaterOrEqual tags "attr>=value".
	FilterGreaterOrEqual = 5
	// FilterLessOrEqual tags "attr<=value".
	FilterLessOrEqual = 6
	// FilterPresent tags "attr=*".
	FilterPresent = 7
	// FilterApproxMatch tags "attr~=value".
	FilterApproxMatch = 8
	// FilterExtensibleMatch is declared for completeness; no code in this
	// file compiles or decompiles it.
	FilterExtensibleMatch = 9
)

// FilterMap gives the human-readable description for each filter type tag.
// The descriptions are attached to the packets built by compileFilter.
var FilterMap = map[uint64]string{
	FilterAnd:             "And",
	FilterOr:              "Or",
	FilterNot:             "Not",
	FilterEqualityMatch:   "Equality Match",
	FilterSubstrings:      "Substrings",
	FilterGreaterOrEqual:  "Greater Or Equal",
	FilterLessOrEqual:     "Less Or Equal",
	FilterPresent:         "Present",
	FilterApproxMatch:     "Approx Match",
	FilterExtensibleMatch: "Extensible Match",
}

// Substring component tags. compileFilter uses them as the context-class tag
// of the string placed inside a FilterSubstrings packet, and DecompileFilter
// uses them to decide where the '*' wildcards go.
const (
	// FilterSubstringsInitial tags a prefix match, written "value*".
	FilterSubstringsInitial = 0
	// FilterSubstringsAny tags a contains match, written "*value*".
	FilterSubstringsAny = 1
	// FilterSubstringsFinal tags a suffix match, written "*value".
	FilterSubstringsFinal = 2
)

// FilterSubstringsMap gives the human-readable description for each
// substring component tag.
var FilterSubstringsMap = map[uint64]string{
	FilterSubstringsInitial: "Substrings Initial",
	FilterSubstringsAny:     "Substrings Any",
	FilterSubstringsFinal:   "Substrings Final",
}

// CompileFilter converts the textual filter (for example "(&(cn=a)(sn=b*))")
// into its BER packet form.
//
// The filter must begin with '(' and must be consumed in full by the parser.
// On any failure the returned packet is nil and the *Error carries
// ErrorFilterCompile.
func CompileFilter(filter string) (*ber.Packet, *Error) {
	if len(filter) == 0 || filter[0] != '(' {
		return nil, NewError(ErrorFilterCompile, os.NewError("Filter does not start with an '('"))
	}
	packet, pos, err := compileFilter(filter, 1)
	if err != nil {
		return nil, err
	}
	// compileFilter can report a position beyond the end of the input when a
	// nested filter is missing its closing ')'. Slicing filter[pos:] in that
	// state would panic, so report it as a truncated filter instead.
	if pos > len(filter) {
		return nil, NewError(ErrorFilterCompile, os.NewError("Unexpected end of filter"))
	}
	if pos != len(filter) {
		return nil, NewError(ErrorFilterCompile, os.NewError("Finished compiling filter with extra at end.\n"+fmt.Sprint(filter[pos:])))
	}
	return packet, nil
}

// DecompileFilter renders a filter packet back into its textual form, for
// example "(&(cn=a)(sn=b*))".
//
// packet.Tag selects the filter type. And/Or/Not recurse into the packet's
// children; the comparison filters read the attribute from Children[0] and,
// where applicable, the value from Children[1]. Tags without a case below
// (FilterExtensibleMatch among them) yield "()" and no error.
//
// Any panic raised while walking the packet (nil packet, missing children)
// is converted into an ErrorFilterDecompile error. When an error is returned,
// ret holds whatever text had been produced so far and should be ignored.
func DecompileFilter(packet *ber.Packet) (ret string, err *Error) {
	defer func() {
		if r := recover(); r != nil {
			err = NewError(ErrorFilterDecompile, os.NewError("Error decompiling filter"))
		}
	}()
	ret = "("
	err = nil
	childStr := ""

	switch packet.Tag {
	case FilterAnd, FilterOr:
		if packet.Tag == FilterAnd {
			ret += "&"
		} else {
			ret += "|"
		}
		for _, child := range packet.Children {
			childStr, err = DecompileFilter(child)
			if err != nil {
				return
			}
			ret += childStr
		}
	case FilterNot:
		ret += "!"
		childStr, err = DecompileFilter(packet.Children[0])
		if err != nil {
			return
		}
		ret += childStr

	case FilterSubstrings:
		ret += ber.DecodeString(packet.Children[0].Data.Bytes())
		ret += "="
		substrings := packet.Children[1].Children
		if len(substrings) == 0 {
			err = NewError(ErrorFilterDecompile, os.NewError("Error decompiling filter"))
			return
		}
		// Render every component of the substring sequence, not only the
		// first one, so "initial*any*final" style filters survive intact.
		// A leading '*' is needed only when the sequence does not start with
		// an initial component; each non-final component supplies the '*'
		// that separates it from whatever follows.
		for i, sub := range substrings {
			switch sub.Tag {
			case FilterSubstringsInitial:
				ret += ber.DecodeString(sub.Data.Bytes()) + "*"
			case FilterSubstringsAny:
				if i == 0 {
					ret += "*"
				}
				ret += ber.DecodeString(sub.Data.Bytes()) + "*"
			case FilterSubstringsFinal:
				if i == 0 {
					ret += "*"
				}
				ret += ber.DecodeString(sub.Data.Bytes())
			}
		}
	case FilterEqualityMatch:
		ret += ber.DecodeString(packet.Children[0].Data.Bytes())
		ret += "="
		ret += ber.DecodeString(packet.Children[1].Data.Bytes())
	case FilterGreaterOrEqual:
		ret += ber.DecodeString(packet.Children[0].Data.Bytes())
		ret += ">="
		ret += ber.DecodeString(packet.Children[1].Data.Bytes())
	case FilterLessOrEqual:
		ret += ber.DecodeString(packet.Children[0].Data.Bytes())
		ret += "<="
		ret += ber.DecodeString(packet.Children[1].Data.Bytes())
	case FilterPresent:
		ret += ber.DecodeString(packet.Children[0].Data.Bytes())
		ret += "=*"
	case FilterApproxMatch:
		ret += ber.DecodeString(packet.Children[0].Data.Bytes())
		ret += "~="
		ret += ber.DecodeString(packet.Children[1].Data.Bytes())
	}

	ret += ")"
	return
}

// compileFilterSet compiles the operand list of an '&' or '|' filter.
//
// Starting at pos it compiles each consecutive "(...)" component and appends
// the result to parent. The list must be terminated by ')'. On success the
// returned index is the position just past that ')'. On failure the returned
// index is the position at which parsing stopped and the *Error is non-nil.
func compileFilterSet(filter string, pos int, parent *ber.Packet) (int, *Error) {
	for pos < len(filter) && filter[pos] == '(' {
		child, next, err := compileFilter(filter, pos+1)
		if err != nil {
			return pos, err
		}
		pos = next
		parent.AppendChild(child)
	}
	// compileFilter may hand back a position past the end of the input, so
	// test with >= rather than ==.
	if pos >= len(filter) {
		return pos, NewError(ErrorFilterCompile, os.NewError("Unexpected end of filter"))
	}
	// Whatever stopped the loop must be the ')' closing the set; any other
	// character must not be silently consumed as if it were the terminator.
	if filter[pos] != ')' {
		return pos, NewError(ErrorFilterCompile, os.NewError("Expected ')' at end of filter set"))
	}

	return pos + 1, nil
}

// compileFilter compiles the single filter component that starts at
// filter[pos], where pos is the index just past the component's opening '('.
//
// The first byte selects the form:
//
//	'('  a nested, parenthesised component; the ')' of the enclosing
//	     component is consumed after it
//	'&'  And, operands compiled by compileFilterSet
//	'|'  Or, operands compiled by compileFilterSet
//	'!'  Not, with exactly one operand
//	else an "attr OP value" item where OP is one of =, >=, <=, ~=
//
// For items, "attr=*" becomes FilterPresent, and an '=' value with a leading
// and/or trailing '*' becomes FilterSubstrings with a single Any, Final or
// Initial component.
// NOTE(review): a value with an interior '*' only (e.g. "a*b") is encoded as
// an equality match on the literal text, and backslash escapes are not
// interpreted.
//
// It returns the packet, the index just past the component, and a non-nil
// *Error (ErrorFilterCompile) on failure. Any panic raised while compiling is
// converted into such an error by the deferred recover.
func compileFilter(filter string, pos int) (p *ber.Packet, new_pos int, err *Error) {
	defer func() {
		if r := recover(); r != nil {
			err = NewError(ErrorFilterCompile, os.NewError("Error compiling filter"))
		}
	}()
	p = nil
	new_pos = pos
	err = nil

	if pos >= len(filter) {
		err = NewError(ErrorFilterCompile, os.NewError("Unexpected end of filter"))
		return
	}

	switch filter[pos] {
	case '(':
		p, new_pos, err = compileFilter(filter, pos+1)
		if err != nil {
			return
		}
		// The nested component must be followed by the ')' of the enclosing
		// component; verify it instead of skipping an arbitrary byte (or
		// stepping past the end of the input).
		if new_pos >= len(filter) || filter[new_pos] != ')' {
			p = nil
			err = NewError(ErrorFilterCompile, os.NewError("Reached end of filter without closing parens"))
			return
		}
		new_pos++
		return
	case '&':
		p = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterAnd, nil, FilterMap[FilterAnd])
		new_pos, err = compileFilterSet(filter, pos+1, p)
		return
	case '|':
		p = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterOr, nil, FilterMap[FilterOr])
		new_pos, err = compileFilterSet(filter, pos+1, p)
		return
	case '!':
		p = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterNot, nil, FilterMap[FilterNot])
		var child *ber.Packet
		child, new_pos, err = compileFilter(filter, pos+1)
		if err != nil {
			// Keep the operand's own error rather than trying to append a
			// child that was never built.
			return
		}
		p.AppendChild(child)
		return
	default:
		// Scan up to the closing ')'. The first comparison operator splits
		// the text into attribute and value; both are later taken as slices
		// of filter so multi-byte UTF-8 sequences are carried over unchanged.
		attrEnd := pos
		valueStart := pos
		for new_pos < len(filter) && filter[new_pos] != ')' {
			if p == nil {
				followedByEq := new_pos+1 < len(filter) && filter[new_pos+1] == '='
				switch {
				case filter[new_pos] == '=':
					p = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterEqualityMatch, nil, FilterMap[FilterEqualityMatch])
					attrEnd = new_pos
				case filter[new_pos] == '>' && followedByEq:
					p = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterGreaterOrEqual, nil, FilterMap[FilterGreaterOrEqual])
					attrEnd = new_pos
					new_pos++
				case filter[new_pos] == '<' && followedByEq:
					p = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterLessOrEqual, nil, FilterMap[FilterLessOrEqual])
					attrEnd = new_pos
					new_pos++
				case filter[new_pos] == '~' && followedByEq:
					p = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterApproxMatch, nil, FilterMap[FilterApproxMatch])
					attrEnd = new_pos
					new_pos++
				}
				if p != nil {
					// new_pos now sits on the operator's last byte.
					valueStart = new_pos + 1
				}
			}
			new_pos++
		}
		if new_pos >= len(filter) {
			p = nil
			err = NewError(ErrorFilterCompile, os.NewError("Unexpected end of filter"))
			return
		}
		if p == nil {
			err = NewError(ErrorFilterCompile, os.NewError("Error parsing filter"))
			return
		}
		attribute := filter[pos:attrEnd]
		condition := filter[valueStart:new_pos]

		p.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimative, ber.TagOctetString, attribute, "Attribute"))

		leadingStar := len(condition) > 0 && condition[0] == '*'
		trailingStar := len(condition) > 0 && condition[len(condition)-1] == '*'
		switch {
		case p.Tag == FilterEqualityMatch && condition == "*":
			p.Tag = FilterPresent
			p.Description = FilterMap[uint64(p.Tag)]
		case p.Tag == FilterEqualityMatch && (leadingStar || trailingStar):
			p.Tag = FilterSubstrings
			p.Description = FilterMap[uint64(p.Tag)]
			seq := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Substrings")
			switch {
			case leadingStar && trailingStar:
				seq.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimative, FilterSubstringsAny, condition[1:len(condition)-1], "Any Substring"))
			case leadingStar:
				seq.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimative, FilterSubstringsFinal, condition[1:], "Final Substring"))
			default:
				seq.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimative, FilterSubstringsInitial, condition[:len(condition)-1], "Initial Substring"))
			}
			p.AppendChild(seq)
		default:
			// Plain comparison, including an empty value such as "(cn=)".
			p.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimative, ber.TagOctetString, condition, "Condition"))
		}
		new_pos++ // step over the closing ')'
		return
	}
	// Not reached: every case above returns. Kept as a terminating return,
	// which older Go compilers require at the end of a function with results.
	err = NewError(ErrorFilterCompile, os.NewError("Reached end of filter without closing parens"))
	return
}