From 611f66a4923b5ba48bb127943385180088dc1e2f Mon Sep 17 00:00:00 2001 From: Michael Mitton Date: Fri, 18 Feb 2011 13:20:30 -0500 Subject: Fixed message processing deadlocks and added mutex for closing function --- conn.go | 71 ++++++++++++++++++++++++++++++++++++++++++++++------------------- 1 file changed, 51 insertions(+), 20 deletions(-) diff --git a/conn.go b/conn.go index 18d778c..41e69fb 100644 --- a/conn.go +++ b/conn.go @@ -11,6 +11,7 @@ import ( "fmt" "net" "os" + "sync" ) // LDAP Connection @@ -22,6 +23,8 @@ type Conn struct { chanResults map[ uint64 ] chan *ber.Packet chanProcessMessage chan *messagePacket chanMessageID chan uint64 + + closeLock sync.Mutex } // Dial connects to the given address on the given network using net.Dial @@ -87,11 +90,10 @@ func (l *Conn) start() { // Close closes the connection. func (l *Conn) Close() *Error { - if l.chanProcessMessage != nil { - message_packet := &messagePacket{ Op: MessageQuit } - l.chanProcessMessage <- message_packet - l.chanProcessMessage = nil - } + l.closeLock.Lock() + defer l.closeLock.Unlock() + + l.sendProcessMessage( &messagePacket{ Op: MessageQuit } ) if l.conn != nil { err := l.conn.Close() @@ -104,9 +106,10 @@ func (l *Conn) Close() *Error { } // Returns the next available messageID -func (l *Conn) nextMessageID() uint64 { - messageID := <-l.chanMessageID - return messageID +func (l *Conn) nextMessageID() (messageID uint64) { + defer func() { if r := recover(); r != nil { messageID = 0 } }() + messageID = <-l.chanMessageID + return } // StartTLS sends the command to start a TLS session and then creates a new TLS Client @@ -170,12 +173,12 @@ func (l *Conn) sendMessage( p *ber.Packet ) (out chan *ber.Packet, err *Error) { message_id := p.Children[ 0 ].Value.(uint64) out = make(chan *ber.Packet) - message_packet := &messagePacket{ Op: MessageRequest, MessageID: message_id, Packet: p, Channel: out } if l.chanProcessMessage == nil { err = NewError( ErrorNetwork, os.NewError( "Connection closed" ) ) return } - l.chanProcessMessage <- message_packet + message_packet := &messagePacket{ Op: MessageRequest, MessageID: message_id, Packet: p, Channel: out } + l.sendProcessMessage( message_packet ) return } @@ -208,7 +211,20 @@ func (l *Conn) processMessages() { fmt.Printf( "Sending message %d\n", message_packet.MessageID ) } l.chanResults[ message_packet.MessageID ] = message_packet.Channel - l.conn.Write( message_packet.Packet.Bytes() ) + buf := message_packet.Packet.Bytes() + for len( buf ) > 0 { + n, err := l.conn.Write( buf ) + if err != nil { + if l.Debug { + fmt.Printf( "Error Sending Message: %s\n", err.String() ) + } + return + } + if n == len( buf ) { + break + } + buf = buf[n:] + } case MessageResponse: // Pass back to waiting goroutine if l.Debug { @@ -216,9 +232,11 @@ func (l *Conn) processMessages() { } chanResult := l.chanResults[ message_packet.MessageID ] if chanResult == nil { - fmt.Printf( "Unexpected Message Result: %d", message_id ) + fmt.Printf( "Unexpected Message Result: %d\n", message_id ) + ber.PrintPacket( message_packet.Packet ) } else { - chanResult <- message_packet.Packet + go func() { chanResult <- message_packet.Packet }() + // chanResult <- message_packet.Packet } case MessageFinish: // Remove from message list @@ -232,6 +250,7 @@ func (l *Conn) processMessages() { } func (l *Conn) closeAllChannels() { +fmt.Printf( "closeAllChannels\n" ) for MessageID, Channel := range l.chanResults { if l.Debug { fmt.Printf( "Closing channel for MessageID %d\n", MessageID ); @@ -241,30 +260,42 @@ func (l *Conn) closeAllChannels() { } close( l.chanMessageID ) l.chanMessageID = nil + + close( l.chanProcessMessage ) + l.chanProcessMessage = nil } func (l *Conn) finishMessage( MessageID uint64 ) { message_packet := &messagePacket{ Op: MessageFinish, MessageID: MessageID } - if l.chanProcessMessage != nil { - l.chanProcessMessage <- message_packet - } + l.sendProcessMessage( message_packet ) } func (l *Conn) reader() { + defer l.Close() for { p, err := ber.ReadPacket( l.conn ) if err != nil { if l.Debug { fmt.Printf( "ldap.reader: %s\n", err.String() ) } - break + return } + addLDAPDescriptions( p ) + message_id := p.Children[ 0 ].Value.(uint64) message_packet := &messagePacket{ Op: MessageResponse, MessageID: message_id, Packet: p } - l.chanProcessMessage <- message_packet + if l.chanProcessMessage != nil { + l.chanProcessMessage <- message_packet + } else { + fmt.Printf( "ldap.reader: Cannot return message\n" ) + return + } } - - l.Close() } +func (l *Conn) sendProcessMessage( message *messagePacket ) { + if l.chanProcessMessage != nil { + go func() { l.chanProcessMessage <- message }() + } +} -- cgit v1.2.3