// Copyright 2011 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. // This package provides LDAP client functions. package ldap import ( "github.com/mmitton/asn1-ber" "crypto/tls" "fmt" "net" "os" ) // LDAP Connection type Conn struct { conn net.Conn isSSL bool Debug bool chanResults map[ uint64 ] chan *ber.Packet chanProcessMessage chan *messagePacket chanMessageID chan uint64 } // Dial connects to the given address on the given network using net.Dial // and then returns a new Conn for the connection. func Dial(network, addr string) (*Conn, *Error) { c, err := net.Dial(network, "", addr) if err != nil { return nil, NewError( ErrorNetwork, err ) } conn := NewConn(c) conn.start() return conn, nil } // Dial connects to the given address on the given network using net.Dial // and then sets up SSL connection and returns a new Conn for the connection. func DialSSL(network, addr string) (*Conn, *Error) { c, err := tls.Dial(network, "", addr, nil) if err != nil { return nil, NewError( ErrorNetwork, err ) } conn := NewConn(c) conn.isSSL = true conn.start() return conn, nil } // Dial connects to the given address on the given network using net.Dial // and then starts a TLS session and returns a new Conn for the connection. func DialTLS(network, addr string) (*Conn, *Error) { c, err := net.Dial(network, "", addr) if err != nil { return nil, NewError( ErrorNetwork, err ) } conn := NewConn(c) err = conn.startTLS() if err != nil { conn.Close() return nil, NewError( ErrorNetwork, err ) } conn.start() return conn, nil } // NewConn returns a new Conn using conn for network I/O. func NewConn(conn net.Conn) *Conn { return &Conn{ conn: conn, isSSL: false, Debug: false, chanResults: map[uint64] chan *ber.Packet{}, chanProcessMessage: make( chan *messagePacket ), chanMessageID: make( chan uint64 ), } } func (l *Conn) start() { go l.reader() go l.processMessages() } // Close closes the connection. func (l *Conn) Close() *Error { if l.chanProcessMessage != nil { message_packet := &messagePacket{ Op: MessageQuit } l.chanProcessMessage <- message_packet l.chanProcessMessage = nil } if l.conn != nil { err := l.conn.Close() if err != nil { return NewError( ErrorNetwork, err ) } l.conn = nil } return nil } // Returns the next available messageID func (l *Conn) nextMessageID() uint64 { messageID := <-l.chanMessageID return messageID } // StartTLS sends the command to start a TLS session and then creates a new TLS Client func (l *Conn) startTLS() *Error { messageID := l.nextMessageID() if l.isSSL { return NewError( ErrorNetwork, os.NewError( "Already encrypted" ) ) } packet := ber.Encode( ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request" ) packet.AppendChild( ber.NewInteger( ber.ClassUniversal, ber.TypePrimative, ber.TagInteger, messageID, "MessageID" ) ) startTLS := ber.Encode( ber.ClassApplication, ber.TypeConstructed, ApplicationExtendedRequest, nil, "Start TLS" ) startTLS.AppendChild( ber.NewString( ber.ClassContext, ber.TypePrimative, 0, "1.3.6.1.4.1.1466.20037", "TLS Extended Command" ) ) packet.AppendChild( startTLS ) if l.Debug { ber.PrintPacket( packet ) } _, err := l.conn.Write( packet.Bytes() ) if err != nil { return NewError( ErrorNetwork, err ) } packet, err = ber.ReadPacket( l.conn ) if err != nil { return NewError( ErrorNetwork, err ) } if l.Debug { if err := addLDAPDescriptions( packet ); err != nil { return NewError( ErrorDebugging, err ) } ber.PrintPacket( packet ) } if packet.Children[ 1 ].Children[ 0 ].Value.(uint64) == 0 { conn := tls.Client( l.conn, nil ) l.isSSL = true l.conn = conn } return nil } const ( MessageQuit = 0 MessageRequest = 1 MessageResponse = 2 MessageFinish = 3 ) type messagePacket struct { Op int MessageID uint64 Packet *ber.Packet Channel chan *ber.Packet } func (l *Conn) sendMessage( p *ber.Packet ) (out chan *ber.Packet, err *Error) { message_id := p.Children[ 0 ].Value.(uint64) out = make(chan *ber.Packet) message_packet := &messagePacket{ Op: MessageRequest, MessageID: message_id, Packet: p, Channel: out } if l.chanProcessMessage == nil { err = NewError( ErrorNetwork, os.NewError( "Connection closed" ) ) return } l.chanProcessMessage <- message_packet return } func (l *Conn) processMessages() { defer l.closeAllChannels() var message_id uint64 = 1 var message_packet *messagePacket for { select { case l.chanMessageID <- message_id: if l.conn == nil { return } message_id++ case message_packet = <-l.chanProcessMessage: if l.conn == nil { return } switch message_packet.Op { case MessageQuit: // Close all channels and quit if l.Debug { fmt.Printf( "Shutting down\n" ) } return case MessageRequest: // Add to message list and write to network if l.Debug { fmt.Printf( "Sending message %d\n", message_packet.MessageID ) } l.chanResults[ message_packet.MessageID ] = message_packet.Channel l.conn.Write( message_packet.Packet.Bytes() ) case MessageResponse: // Pass back to waiting goroutine if l.Debug { fmt.Printf( "Receiving message %d\n", message_packet.MessageID ) } chanResult := l.chanResults[ message_packet.MessageID ] if chanResult == nil { fmt.Printf( "Unexpected Message Result: %d", message_id ) } else { chanResult <- message_packet.Packet } case MessageFinish: // Remove from message list if l.Debug { fmt.Printf( "Finished message %d\n", message_packet.MessageID ) } l.chanResults[ message_packet.MessageID ] = nil, false } } } } func (l *Conn) closeAllChannels() { for MessageID, Channel := range l.chanResults { if l.Debug { fmt.Printf( "Closing channel for MessageID %d\n", MessageID ); } close( Channel ) l.chanResults[ MessageID ] = nil, false } close( l.chanMessageID ) l.chanMessageID = nil } func (l *Conn) finishMessage( MessageID uint64 ) { message_packet := &messagePacket{ Op: MessageFinish, MessageID: MessageID } if l.chanProcessMessage != nil { l.chanProcessMessage <- message_packet } } func (l *Conn) reader() { for { p, err := ber.ReadPacket( l.conn ) if err != nil { if l.Debug { fmt.Printf( "ldap.reader: %s\n", err.String() ) } break } message_id := p.Children[ 0 ].Value.(uint64) message_packet := &messagePacket{ Op: MessageResponse, MessageID: message_id, Packet: p } l.chanProcessMessage <- message_packet } l.Close() }