summaryrefslogtreecommitdiffstats
path: root/conn.go
diff options
context:
space:
mode:
Diffstat (limited to 'conn.go')
-rw-r--r--conn.go270
1 files changed, 270 insertions, 0 deletions
diff --git a/conn.go b/conn.go
new file mode 100644
index 0000000..18d778c
--- /dev/null
+++ b/conn.go
@@ -0,0 +1,270 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// This package provides LDAP client functions.
+package ldap
+
+import (
+ "github.com/mmitton/asn1-ber"
+ "crypto/tls"
+ "fmt"
+ "net"
+ "os"
+)
+
+// LDAP Connection
+type Conn struct {
+ conn net.Conn
+ isSSL bool
+ Debug bool
+
+ chanResults map[ uint64 ] chan *ber.Packet
+ chanProcessMessage chan *messagePacket
+ chanMessageID chan uint64
+}
+
+// Dial connects to the given address on the given network using net.Dial
+// and then returns a new Conn for the connection.
+func Dial(network, addr string) (*Conn, *Error) {
+ c, err := net.Dial(network, "", addr)
+ if err != nil {
+ return nil, NewError( ErrorNetwork, err )
+ }
+ conn := NewConn(c)
+ conn.start()
+ return conn, nil
+}
+
+// Dial connects to the given address on the given network using net.Dial
+// and then sets up SSL connection and returns a new Conn for the connection.
+func DialSSL(network, addr string) (*Conn, *Error) {
+ c, err := tls.Dial(network, "", addr, nil)
+ if err != nil {
+ return nil, NewError( ErrorNetwork, err )
+ }
+ conn := NewConn(c)
+ conn.isSSL = true
+
+ conn.start()
+ return conn, nil
+}
+
+// Dial connects to the given address on the given network using net.Dial
+// and then starts a TLS session and returns a new Conn for the connection.
+func DialTLS(network, addr string) (*Conn, *Error) {
+ c, err := net.Dial(network, "", addr)
+ if err != nil {
+ return nil, NewError( ErrorNetwork, err )
+ }
+ conn := NewConn(c)
+
+ err = conn.startTLS()
+ if err != nil {
+ conn.Close()
+ return nil, NewError( ErrorNetwork, err )
+ }
+ conn.start()
+ return conn, nil
+}
+
+// NewConn returns a new Conn using conn for network I/O.
+func NewConn(conn net.Conn) *Conn {
+ return &Conn{
+ conn: conn,
+ isSSL: false,
+ Debug: false,
+ chanResults: map[uint64] chan *ber.Packet{},
+ chanProcessMessage: make( chan *messagePacket ),
+ chanMessageID: make( chan uint64 ),
+ }
+}
+
+func (l *Conn) start() {
+ go l.reader()
+ go l.processMessages()
+}
+
+// Close closes the connection.
+func (l *Conn) Close() *Error {
+ if l.chanProcessMessage != nil {
+ message_packet := &messagePacket{ Op: MessageQuit }
+ l.chanProcessMessage <- message_packet
+ l.chanProcessMessage = nil
+ }
+
+ if l.conn != nil {
+ err := l.conn.Close()
+ if err != nil {
+ return NewError( ErrorNetwork, err )
+ }
+ l.conn = nil
+ }
+ return nil
+}
+
+// Returns the next available messageID
+func (l *Conn) nextMessageID() uint64 {
+ messageID := <-l.chanMessageID
+ return messageID
+}
+
+// StartTLS sends the command to start a TLS session and then creates a new TLS Client
+func (l *Conn) startTLS() *Error {
+ messageID := l.nextMessageID()
+
+ if l.isSSL {
+ return NewError( ErrorNetwork, os.NewError( "Already encrypted" ) )
+ }
+
+ packet := ber.Encode( ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request" )
+ packet.AppendChild( ber.NewInteger( ber.ClassUniversal, ber.TypePrimative, ber.TagInteger, messageID, "MessageID" ) )
+ startTLS := ber.Encode( ber.ClassApplication, ber.TypeConstructed, ApplicationExtendedRequest, nil, "Start TLS" )
+ startTLS.AppendChild( ber.NewString( ber.ClassContext, ber.TypePrimative, 0, "1.3.6.1.4.1.1466.20037", "TLS Extended Command" ) )
+ packet.AppendChild( startTLS )
+ if l.Debug {
+ ber.PrintPacket( packet )
+ }
+
+ _, err := l.conn.Write( packet.Bytes() )
+ if err != nil {
+ return NewError( ErrorNetwork, err )
+ }
+
+ packet, err = ber.ReadPacket( l.conn )
+ if err != nil {
+ return NewError( ErrorNetwork, err )
+ }
+
+ if l.Debug {
+ if err := addLDAPDescriptions( packet ); err != nil {
+ return NewError( ErrorDebugging, err )
+ }
+ ber.PrintPacket( packet )
+ }
+
+ if packet.Children[ 1 ].Children[ 0 ].Value.(uint64) == 0 {
+ conn := tls.Client( l.conn, nil )
+ l.isSSL = true
+ l.conn = conn
+ }
+
+ return nil
+}
+
+const (
+ MessageQuit = 0
+ MessageRequest = 1
+ MessageResponse = 2
+ MessageFinish = 3
+)
+
+type messagePacket struct {
+ Op int
+ MessageID uint64
+ Packet *ber.Packet
+ Channel chan *ber.Packet
+}
+
+func (l *Conn) sendMessage( p *ber.Packet ) (out chan *ber.Packet, err *Error) {
+ message_id := p.Children[ 0 ].Value.(uint64)
+ out = make(chan *ber.Packet)
+
+ message_packet := &messagePacket{ Op: MessageRequest, MessageID: message_id, Packet: p, Channel: out }
+ if l.chanProcessMessage == nil {
+ err = NewError( ErrorNetwork, os.NewError( "Connection closed" ) )
+ return
+ }
+ l.chanProcessMessage <- message_packet
+ return
+}
+
+func (l *Conn) processMessages() {
+ defer l.closeAllChannels()
+
+ var message_id uint64 = 1
+ var message_packet *messagePacket
+ for {
+ select {
+ case l.chanMessageID <- message_id:
+ if l.conn == nil {
+ return
+ }
+ message_id++
+ case message_packet = <-l.chanProcessMessage:
+ if l.conn == nil {
+ return
+ }
+ switch message_packet.Op {
+ case MessageQuit:
+ // Close all channels and quit
+ if l.Debug {
+ fmt.Printf( "Shutting down\n" )
+ }
+ return
+ case MessageRequest:
+ // Add to message list and write to network
+ if l.Debug {
+ fmt.Printf( "Sending message %d\n", message_packet.MessageID )
+ }
+ l.chanResults[ message_packet.MessageID ] = message_packet.Channel
+ l.conn.Write( message_packet.Packet.Bytes() )
+ case MessageResponse:
+ // Pass back to waiting goroutine
+ if l.Debug {
+ fmt.Printf( "Receiving message %d\n", message_packet.MessageID )
+ }
+ chanResult := l.chanResults[ message_packet.MessageID ]
+ if chanResult == nil {
+ fmt.Printf( "Unexpected Message Result: %d", message_id )
+ } else {
+ chanResult <- message_packet.Packet
+ }
+ case MessageFinish:
+ // Remove from message list
+ if l.Debug {
+ fmt.Printf( "Finished message %d\n", message_packet.MessageID )
+ }
+ l.chanResults[ message_packet.MessageID ] = nil, false
+ }
+ }
+ }
+}
+
+func (l *Conn) closeAllChannels() {
+ for MessageID, Channel := range l.chanResults {
+ if l.Debug {
+ fmt.Printf( "Closing channel for MessageID %d\n", MessageID );
+ }
+ close( Channel )
+ l.chanResults[ MessageID ] = nil, false
+ }
+ close( l.chanMessageID )
+ l.chanMessageID = nil
+}
+
+func (l *Conn) finishMessage( MessageID uint64 ) {
+ message_packet := &messagePacket{ Op: MessageFinish, MessageID: MessageID }
+ if l.chanProcessMessage != nil {
+ l.chanProcessMessage <- message_packet
+ }
+}
+
+func (l *Conn) reader() {
+ for {
+ p, err := ber.ReadPacket( l.conn )
+ if err != nil {
+ if l.Debug {
+ fmt.Printf( "ldap.reader: %s\n", err.String() )
+ }
+ break
+ }
+
+ message_id := p.Children[ 0 ].Value.(uint64)
+ message_packet := &messagePacket{ Op: MessageResponse, MessageID: message_id, Packet: p }
+ l.chanProcessMessage <- message_packet
+ }
+
+ l.Close()
+}
+