summaryrefslogblamecommitdiffstats
path: root/conn.go
blob: 18d778cd4840dd1e0b4f0e144185a4f7499d020a (plain) (tree)













































































































































































































































































                                                                                                                                    
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Package ldap provides LDAP client functions.
package ldap

import (
   "github.com/mmitton/asn1-ber"
   "crypto/tls"
   "fmt"
   "net"
   "os"
)

// Conn is an LDAP connection. After one of the Dial functions returns it,
// traffic is multiplexed over conn by two goroutines launched by start():
// reader (decodes incoming packets) and processMessages (hands out message
// IDs, writes requests and routes responses), which talk over the channels
// below.
type Conn struct {
	conn net.Conn // underlying transport; replaced by a tls.Client after a successful startTLS
   isSSL bool // true once the connection is encrypted (DialSSL or startTLS)
   Debug bool // when true, packets are dumped with ber.PrintPacket and message routing is logged with fmt.Printf

   chanResults map[ uint64 ] chan *ber.Packet // message ID -> channel its responses are delivered on; in this file only touched by processMessages and its deferred closeAllChannels
   chanProcessMessage chan *messagePacket // work queue consumed by processMessages; set to nil by Close
   chanMessageID chan uint64 // processMessages offers successive message IDs here; closed and set to nil when it exits
}

// Dial opens a plain, unencrypted connection to addr on the given network via
// net.Dial, starts the connection's background goroutines and returns the
// resulting Conn. A dial failure is reported as an ErrorNetwork *Error.
func Dial(network, addr string) (*Conn, *Error) {
	netConn, dialErr := net.Dial(network, "", addr)
	if dialErr != nil {
		return nil, NewError(ErrorNetwork, dialErr)
	}
	l := NewConn(netConn)
	l.start()
	return l, nil
}

// DialSSL connects to addr on the given network with tls.Dial, so the link is
// encrypted from the first byte. It marks the Conn as SSL, starts its
// background goroutines and returns it. A dial failure is reported as an
// ErrorNetwork *Error.
func DialSSL(network, addr string) (*Conn, *Error) {
	tlsConn, dialErr := tls.Dial(network, "", addr, nil)
	if dialErr != nil {
		return nil, NewError(ErrorNetwork, dialErr)
	}
	l := NewConn(tlsConn)
	l.isSSL = true
	l.start()
	return l, nil
}

// DialTLS connects to addr on the given network using net.Dial and upgrades
// the connection with the LDAP StartTLS extended operation. Only after that
// does it start the Conn's background goroutines and return the Conn.
//
// startTLS writes to and reads from the socket itself, so it must run before
// start(). Otherwise reader would consume the StartTLS response.
//
// A dial failure is reported as an ErrorNetwork *Error. A StartTLS failure is
// returned as the *Error produced by startTLS, and the socket is closed.
func DialTLS(network, addr string) (*Conn, *Error) {
	c, err := net.Dial(network, "", addr)
	if err != nil {
		return nil, NewError(ErrorNetwork, err)
	}
	conn := NewConn(c)

	// Keep startTLS's result in its own *Error variable. Assigning it to the
	// os.Error `err` above would turn a nil *Error into a non-nil interface,
	// so every successful upgrade would be reported as a failure.
	if tlsErr := conn.startTLS(); tlsErr != nil {
		// conn.Close() would block forever here. It sends a quit message to
		// processMessages, which has not been started yet. Close the socket
		// directly instead; this is best-effort cleanup on an error path, so
		// its own error is ignored.
		c.Close()
		return nil, tlsErr
	}
	conn.start()
	return conn, nil
}

// NewConn wraps an already-established net.Conn in a Conn. The Conn starts
// unencrypted, with Debug off and no outstanding requests. It gets fresh
// unbuffered channels for message processing. Its background goroutines are
// not started here; the Dial functions do that.
func NewConn(conn net.Conn) *Conn {
	// isSSL and Debug are left at their zero value, false.
	l := &Conn{conn: conn}
	l.chanResults = make(map[uint64]chan *ber.Packet)
	l.chanProcessMessage = make(chan *messagePacket)
	l.chanMessageID = make(chan uint64)
	return l
}

// start launches the two goroutines that drive the connection:
// processMessages, which hands out message IDs, writes requests and routes
// responses; and reader, which decodes incoming packets.
func (l *Conn) start() {
	go l.processMessages()
	go l.reader()
}

// Close shuts the connection down in two steps:
//
//  1. If the message-processing channel is still set, it sends a MessageQuit
//     on it, blocking until the message is received, and then clears the
//     channel so later calls skip this step.
//  2. It closes the underlying net.Conn and clears it.
//
// A failure from net.Conn.Close is returned as an ErrorNetwork *Error, and in
// that case l.conn is left in place.
func (l *Conn) Close() *Error {
	if ch := l.chanProcessMessage; ch != nil {
		ch <- &messagePacket{Op: MessageQuit}
		l.chanProcessMessage = nil
	}

	if l.conn == nil {
		return nil
	}
	if closeErr := l.conn.Close(); closeErr != nil {
		return NewError(ErrorNetwork, closeErr)
	}
	l.conn = nil
	return nil
}

// nextMessageID receives the next message ID from chanMessageID. It blocks
// until processMessages, which feeds that unbuffered channel, offers one.
// Once the channel has been closed it yields 0.
func (l *Conn) nextMessageID() uint64 {
	return <-l.chanMessageID
}

// startTLS upgrades the connection with the LDAP StartTLS extended operation
// (request name 1.3.6.1.4.1.1466.20037). It writes the request directly to
// l.conn and reads the reply directly from l.conn. If the reply's result code
// is 0, it replaces l.conn with a tls.Client wrapping it and sets l.isSSL.
//
// Because it does its own I/O on l.conn, it must run before start(); once
// reader is running, reader would consume the reply. For the same reason it
// cannot take an ID from nextMessageID. That channel is fed by
// processMessages, which is not running yet, so the receive would block
// forever.
//
// It returns an ErrorNetwork *Error in any of these cases:
//   - the connection is already encrypted;
//   - the write or the read fails;
//   - the reply is malformed;
//   - the server answers with a non-zero result code. The connection is NOT
//     upgraded in that case.
//
// With Debug set, a failure to annotate the reply yields ErrorDebugging.
func (l *Conn) startTLS() *Error {
	if l.isSSL {
		return NewError(ErrorNetwork, os.NewError("Already encrypted"))
	}

	// Nothing else can be outstanding before start(), so any non-zero ID is
	// unique here. processMessages may hand out 1 again later; RFC 4511 only
	// requires uniqueness among requests still in progress.
	var messageID uint64 = 1

	packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
	packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimative, ber.TagInteger, messageID, "MessageID"))
	startTLS := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationExtendedRequest, nil, "Start TLS")
	startTLS.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimative, 0, "1.3.6.1.4.1.1466.20037", "TLS Extended Command"))
	packet.AppendChild(startTLS)
	if l.Debug {
		ber.PrintPacket(packet)
	}

	_, err := l.conn.Write(packet.Bytes())
	if err != nil {
		return NewError(ErrorNetwork, err)
	}

	packet, err = ber.ReadPacket(l.conn)
	if err != nil {
		return NewError(ErrorNetwork, err)
	}

	if l.Debug {
		if err := addLDAPDescriptions(packet); err != nil {
			return NewError(ErrorDebugging, err)
		}
		ber.PrintPacket(packet)
	}

	// Expected shape: SEQUENCE { messageID, ExtendedResponse { resultCode, ... } }.
	// Check the shape before indexing: the reply comes from the network, and
	// blind indexing panics on a malformed one.
	if len(packet.Children) < 2 || len(packet.Children[1].Children) < 1 {
		return NewError(ErrorNetwork, fmt.Errorf("malformed StartTLS response"))
	}
	resultCode, ok := packet.Children[1].Children[0].Value.(uint64)
	if !ok {
		return NewError(ErrorNetwork, fmt.Errorf("malformed StartTLS response"))
	}
	// A refused StartTLS must not be reported as success. The caller asked for
	// an encrypted connection and would otherwise silently continue in
	// plaintext.
	if resultCode != 0 {
		return NewError(ErrorNetwork, fmt.Errorf("StartTLS refused with result code %d", resultCode))
	}

	// NOTE(review): with a nil *tls.Config, server-name and certificate
	// checking are left to crypto/tls defaults; confirm that is intended.
	l.conn = tls.Client(l.conn, nil)
	l.isSSL = true
	return nil
}

// Operation codes carried in messagePacket.Op and dispatched by
// processMessages.
const (
	MessageQuit     = iota // stop processMessages
	MessageRequest         // register a result channel and write the request
	MessageResponse        // deliver a received packet to its result channel
	MessageFinish          // drop the result channel of a completed request
)

// messagePacket is one unit of work sent to processMessages over
// chanProcessMessage.
type messagePacket struct {
   Op int // one of MessageQuit, MessageRequest, MessageResponse, MessageFinish
   MessageID uint64 // LDAP message ID the operation refers to (unset for MessageQuit)
   Packet *ber.Packet // request to write (MessageRequest) or response received (MessageResponse)
   Channel chan *ber.Packet // where responses for MessageID are delivered; set for MessageRequest
}

// sendMessage queues request p for processMessages to register and write. It
// returns the channel on which responses for p's message ID will be
// delivered; p's first child must hold that ID as a uint64. If the
// connection has been closed (chanProcessMessage is nil), it returns an
// ErrorNetwork *Error instead; out is still a freshly made channel in that
// case.
func (l *Conn) sendMessage(p *ber.Packet) (out chan *ber.Packet, err *Error) {
	id := p.Children[0].Value.(uint64)
	out = make(chan *ber.Packet)

	if l.chanProcessMessage == nil {
		return out, NewError(ErrorNetwork, os.NewError("Connection closed"))
	}

	l.chanProcessMessage <- &messagePacket{
		Op:        MessageRequest,
		MessageID: id,
		Packet:    p,
		Channel:   out,
	}
	return out, nil
}

// processMessages is the connection's dispatcher goroutine. On each iteration
// it either offers the next message ID on chanMessageID, or takes one
// messagePacket from chanProcessMessage and acts on its Op:
//
//	MessageQuit     - return
//	MessageRequest  - register msg.Channel for msg.MessageID, then write the packet
//	MessageResponse - forward the packet to the channel registered for its ID
//	MessageFinish   - forget the channel registered for msg.MessageID
//
// It also returns as soon as it sees l.conn == nil. However it exits, the
// deferred closeAllChannels closes every registered result channel and
// chanMessageID.
//
// NOTE(review): forwarding a response blocks until the requester receives
// it. A requester that stops reading therefore stalls every other request on
// this Conn.
func (l *Conn) processMessages() {
	defer l.closeAllChannels()

	var messageID uint64 = 1
	for {
		select {
		case l.chanMessageID <- messageID:
			if l.conn == nil {
				return
			}
			messageID++
		case msg := <-l.chanProcessMessage:
			if l.conn == nil {
				return
			}
			switch msg.Op {
			case MessageQuit:
				// Close all channels (via the deferred call) and quit.
				if l.Debug {
					fmt.Printf("Shutting down\n")
				}
				return
			case MessageRequest:
				// Add to the message list and write to the network.
				if l.Debug {
					fmt.Printf("Sending message %d\n", msg.MessageID)
				}
				l.chanResults[msg.MessageID] = msg.Channel
				if _, err := l.conn.Write(msg.Packet.Bytes()); err != nil {
					// The request could not be written, so no answer will
					// come. Close the result channel, as closeAllChannels
					// does, so the requester gets nil instead of waiting
					// forever. Unregister it too, so closeAllChannels cannot
					// close it a second time.
					if l.Debug {
						fmt.Printf("Error sending message %d: %s\n", msg.MessageID, err)
					}
					close(msg.Channel)
					l.chanResults[msg.MessageID] = nil, false
				}
			case MessageResponse:
				// Pass the response back to the waiting goroutine.
				if l.Debug {
					fmt.Printf("Receiving message %d\n", msg.MessageID)
				}
				chanResult := l.chanResults[msg.MessageID]
				if chanResult == nil {
					// Report the response's own ID, not the next-ID counter.
					fmt.Printf("Unexpected Message Result: %d\n", msg.MessageID)
				} else {
					chanResult <- msg.Packet
				}
			case MessageFinish:
				// Remove from the message list.
				if l.Debug {
					fmt.Printf("Finished message %d\n", msg.MessageID)
				}
				l.chanResults[msg.MessageID] = nil, false
			}
		}
	}
}

// closeAllChannels runs (deferred) when processMessages exits. For every
// result channel still registered in chanResults, it closes the channel, so
// receivers blocked on it get nil, and removes the entry. It then closes
// chanMessageID and clears that field.
func (l *Conn) closeAllChannels() {
	for id, results := range l.chanResults {
		if l.Debug {
			fmt.Printf("Closing channel for MessageID %d\n", id)
		}
		close(results)
		l.chanResults[id] = nil, false
	}
	close(l.chanMessageID)
	l.chanMessageID = nil
}

// finishMessage tells processMessages that the caller is done with MessageID,
// so its result channel is dropped from chanResults. It is a no-op once the
// connection has been closed (chanProcessMessage is nil).
func (l *Conn) finishMessage(MessageID uint64) {
	if l.chanProcessMessage == nil {
		return
	}
	l.chanProcessMessage <- &messagePacket{Op: MessageFinish, MessageID: MessageID}
}

// reader is started as a goroutine by start(). It decodes BER packets from
// the connection one at a time and hands each to processMessages as a
// MessageResponse, tagged with the packet's message ID (its first child). It
// stops at the first read error or at a packet with no usable message ID, and
// then calls Close.
func (l *Conn) reader() {
	// Read from the net.Conn that was current when reading began. Close sets
	// l.conn to nil, and re-reading the field on each iteration could hand
	// ReadPacket a nil reader (a panic) instead of a closed connection (an
	// ordinary read error).
	conn := l.conn
	for {
		p, err := ber.ReadPacket(conn)
		if err != nil {
			if l.Debug {
				fmt.Printf("ldap.reader: %s\n", err.String())
			}
			break
		}

		// The packet comes from the server. Indexing or asserting blindly would
		// panic, and an unrecovered panic in this goroutine crashes the whole
		// program. Treat a missing ID as a fatal protocol error for this
		// connection instead. The Close below stops processMessages, whose
		// closeAllChannels then releases any requests still waiting.
		var messageID uint64
		ok := len(p.Children) > 0
		if ok {
			messageID, ok = p.Children[0].Value.(uint64)
		}
		if !ok {
			if l.Debug {
				fmt.Printf("ldap.reader: response without a message ID\n")
			}
			break
		}

		l.chanProcessMessage <- &messagePacket{Op: MessageResponse, MessageID: messageID, Packet: p}
	}

	l.Close()
}