summaryrefslogblamecommitdiffstats
path: root/src/mbedTLS++/BlockingSslClientSocket.cpp
blob: 6e6410879d0deaad3711d714ff7654d04e89d502 (plain) (tree)
1
2
3
4
5
6
7
8
9
10
11
12
 










                                                                                                                      











































                                                                                      
                                                                                  




                                                          
                                                  




                                          
                                                                                  


                                          
 
       
 












                                                                                   
                                                          

                            
















                                                                                                                               
 
                                         











                                                                                  
         

                             
 
                              









                                                 

                     
                                                                                                      

                             
 
                                                                                               
                                        
         
                                                              
         
 


                                
                                                                                                 

                             
 






                    
                                                                              
 

                                                                 
                                                                  
                                        

                           

                                                                                                                                          

                  
 












                                                                                       
         
                                                                                                          
         
 
                                       







                                                                           




                                                     
 
                                                         
                                                              


                                     
                                                             

                            

                                                                                                                       
                                                                                                                          




                                     
                                                             











                                            

                                                                       
                                                                                                                


                                                      
                                                                                                                 







                   
                                           





                                  
 
                            
                              








                                                                                       







                                                                                           









                                                                             
         
                                                   
         





                                                                           







                                                                                              


                                                                                                                 
                                                   

                                                
         

                                                                                                                    
         





                                            
 











                                                                         
                                                                                      





















                                                                                 
                                     







                                                   
                              
                         
                      




 

// BlockingSslClientSocket.cpp

// Implements the cBlockingSslClientSocket class representing a blocking TCP socket with client SSL encryption over it

#include "Globals.h"
#include "BlockingSslClientSocket.h"





////////////////////////////////////////////////////////////////////////////////
// cBlockingSslClientSocketConnectCallbacks:

class cBlockingSslClientSocketConnectCallbacks:
	public cNetwork::cConnectCallbacks
{
	/** The socket object that is using this instance of the callbacks. */
	cBlockingSslClientSocket & m_Socket;

	virtual void OnConnected(cTCPLink & a_Link) override
	{
		m_Socket.OnConnected();
	}

	virtual void OnError(int a_ErrorCode, const AString & a_ErrorMsg) override
	{
		m_Socket.OnConnectError(a_ErrorMsg);
	}

public:
	cBlockingSslClientSocketConnectCallbacks(cBlockingSslClientSocket & a_Socket):
		m_Socket(a_Socket)
	{
	}
};





////////////////////////////////////////////////////////////////////////////////
// cBlockingSslClientSocketLinkCallbacks:

class cBlockingSslClientSocketLinkCallbacks:
	public cTCPLink::cCallbacks
{
	cBlockingSslClientSocket & m_Socket;

	virtual void OnLinkCreated(cTCPLinkPtr a_Link) override
	{
		m_Socket.SetLink(a_Link);
	}


	virtual void OnReceivedData(const char * a_Data, size_t a_Length) override
	{
		m_Socket.OnReceivedData(a_Data, a_Length);
	}


	virtual void OnRemoteClosed(void) override
	{
		m_Socket.OnDisconnected();
	}


	virtual void OnError(int a_ErrorCode, const AString & a_ErrorMsg) override
	{
		m_Socket.OnDisconnected();
	}

public:

	cBlockingSslClientSocketLinkCallbacks(cBlockingSslClientSocket & a_Socket):
		m_Socket(a_Socket)
	{
	}
};





////////////////////////////////////////////////////////////////////////////////
// cBlockingSslClientSocket:

cBlockingSslClientSocket::cBlockingSslClientSocket():
	m_Ssl(*this),  // The SSL context is bound to this object, presumably as its transport (see SendEncrypted / ReceiveEncrypted)
	m_IsConnected(false)
{
	// No further initialization is required
}





bool cBlockingSslClientSocket::Connect(const AString & a_ServerName, UInt16 a_Port)
{
	// If already connected, report an error:
	if (m_IsConnected)
	{
		// TODO: Handle this better - if connected to the same server and port, and the socket is alive, return success
		m_LastErrorText = "Already connected";
		return false;
	}

	// Connect the underlying socket:
	m_ServerName = a_ServerName;
	if (!cNetwork::Connect(a_ServerName, a_Port,
		std::make_shared<cBlockingSslClientSocketConnectCallbacks>(*this),
		std::make_shared<cBlockingSslClientSocketLinkCallbacks>(*this))
	)
	{
		return false;
	}

	// Wait for the connection to succeed or fail:
	m_Event.Wait();
	if (!m_IsConnected)
	{
		return false;
	}

	// Initialize the SSL:
	int ret = 0;
	if (m_Config != nullptr)
	{
		ret = m_Ssl.Initialize(m_Config);
	}
	else
	{
		ret = m_Ssl.Initialize(true);
	}

	if (ret != 0)
	{
		m_LastErrorText = fmt::format(FMT_STRING("SSL initialization failed: -0x{:x}"), -ret);
		return false;
	}

	// If we have been assigned a trusted CA root cert store, push it into the SSL context:
	if (!m_ExpectedPeerName.empty())
	{
		m_Ssl.SetExpectedPeerName(m_ExpectedPeerName);
	}

	ret = m_Ssl.Handshake();
	if (ret != 0)
	{
		m_LastErrorText = fmt::format(FMT_STRING("SSL handshake failed: -0x{:x}"), -ret);
		return false;
	}

	return true;
}





void cBlockingSslClientSocket::SetExpectedPeerName(AString a_ExpectedPeerName)
{
	ASSERT(!m_IsConnected);  // Must be called before connect

	// Replacing an already-set name is allowed, but it is most likely a mistake, so warn about it:
	const bool HasPreviousName = !m_ExpectedPeerName.empty();
	if (HasPreviousName)
	{
		LOGWARNING(
			"SSL: Trying to set multiple expected peer names, only the last one will be used. %s overwriting the previous %s",
			a_ExpectedPeerName, m_ExpectedPeerName
		);
	}

	// The parameter is taken by value, so it can be moved into place:
	m_ExpectedPeerName = std::move(a_ExpectedPeerName);
}





void cBlockingSslClientSocket::SetSslConfig(std::shared_ptr<const cSslConfig> a_Config)
{
	ASSERT(!m_IsConnected);  // Must be called before connect

	// Replacing an already-set config is allowed, but it is most likely a mistake, so warn about it:
	const bool HasPreviousConfig = (m_Config != nullptr);
	if (HasPreviousConfig)
	{
		LOGWARNING("SSL: Trying to set multiple configurations, only the last one will be used.");
	}

	// Take over the caller's reference:
	m_Config = std::move(a_Config);
}





bool cBlockingSslClientSocket::Send(const void * a_Data, size_t a_NumBytes)
{
	// Sends all a_NumBytes bytes of plain data through the SSL context.
	// Returns true once everything was written, false on error (m_LastErrorText is set).

	if (!m_IsConnected)
	{
		m_LastErrorText = "Socket is closed";
		return false;
	}

	// Keep sending the data until all of it is sent.
	// WritePlain() may accept only part of the buffer, so each call must be given only the REMAINING byte count;
	// passing the full a_NumBytes after advancing Data would read past the end of the caller's buffer.
	const char * Data = static_cast<const char *>(a_Data);
	size_t NumBytes = a_NumBytes;
	for (;;)
	{
		int res = m_Ssl.WritePlain(Data, NumBytes);
		if (res < 0)
		{
			ASSERT(res != MBEDTLS_ERR_SSL_WANT_READ);   // This should never happen with callback-based SSL
			ASSERT(res != MBEDTLS_ERR_SSL_WANT_WRITE);  // This should never happen with callback-based SSL
			m_LastErrorText = fmt::format(FMT_STRING("Data cannot be written to SSL context: -0x{:x}"), -res);
			return false;
		}

		// Advance past the part that has been written:
		Data += res;
		NumBytes -= static_cast<size_t>(res);
		if (NumBytes == 0)
		{
			return true;
		}
	}
}





int cBlockingSslClientSocket::Receive(void * a_Data, size_t a_MaxBytes)
{
	// The SSL context may still have queued data after the link went down,
	// so the read is attempted regardless of m_IsConnected:
	const int NumRead = m_Ssl.ReadPlain(a_Data, a_MaxBytes);
	if (NumRead >= 0)
	{
		return NumRead;
	}

	// Negative values are errors; record them for the caller and pass the code through unchanged:
	m_LastErrorText = fmt::format(FMT_STRING("Data cannot be read from SSL context: -0x{:x}"), -NumRead);
	return NumRead;
}





void cBlockingSslClientSocket::Disconnect()
{
	// Ignore if not connected
	if (!m_IsConnected)
	{
		return;
	}

	m_Ssl.NotifyClose();
	m_IsConnected = false;

	// Grab a copy of the socket so that we know it doesn't change under our hands:
	auto socket = m_Socket;
	if (socket != nullptr)
	{
		socket->Close();
	}

	m_Socket.reset();
}





int cBlockingSslClientSocket::ReceiveEncrypted(unsigned char * a_Buffer, size_t a_NumBytes)
{
	cCSLock Lock(m_CSIncomingData);

	// Block until there is data to hand out or the link goes down.
	// The lock is released while waiting so that OnReceivedData() can append to the queue:
	for (;;)
	{
		if (!m_IsConnected || !m_IncomingData.empty())
		{
			break;
		}
		cCSUnlock Unlock(Lock);
		m_Event.Wait();
	}

	// A disconnect is reported only once the already-received data has been drained:
	if (!m_IsConnected && m_IncomingData.empty())
	{
		return MBEDTLS_ERR_NET_RECV_FAILED;
	}

	// Hand out as much queued data as fits into the caller's buffer and drop it from the queue:
	const size_t NumCopied = std::min(a_NumBytes, m_IncomingData.size());
	memcpy(a_Buffer, m_IncomingData.data(), NumCopied);
	m_IncomingData.erase(0, NumCopied);
	return static_cast<int>(NumCopied);
}





int cBlockingSslClientSocket::SendEncrypted(const unsigned char * a_Buffer, size_t a_NumBytes)
{
	cTCPLinkPtr Socket(m_Socket);  // Make a copy so that multiple threads don't race on deleting the socket.
	if (Socket == nullptr)
	{
		return MBEDTLS_ERR_NET_SEND_FAILED;
	}
	if (!Socket->Send(a_Buffer, a_NumBytes))
	{
		// mbedTLS's net routines distinguish between connection reset and general failure, we don't need to
		return MBEDTLS_ERR_NET_SEND_FAILED;
	}
	return static_cast<int>(a_NumBytes);
}





void cBlockingSslClientSocket::OnConnected()
{
	// Mark the socket as connected first, then wake up the thread waiting in Connect():
	m_IsConnected = true;
	m_Event.Set();
}





void cBlockingSslClientSocket::OnConnectError(const AString & a_ErrorMsg)
{
	LOG("Cannot connect to %s: \"%s\"", m_ServerName.c_str(), a_ErrorMsg.c_str());
	m_Event.Set();
}





void cBlockingSslClientSocket::OnReceivedData(const char * a_Data, size_t a_Size)
{
	// Queue the incoming bytes while holding the lock (released at the end of this scope):
	{
		cCSLock Lock(m_CSIncomingData);
		m_IncomingData.insert(m_IncomingData.end(), a_Data, a_Data + a_Size);
	}

	// Wake up any reader waiting in ReceiveEncrypted():
	m_Event.Set();
}





void cBlockingSslClientSocket::SetLink(cTCPLinkPtr a_Link)
{
	// Take over the link passed by value; SendEncrypted() and Disconnect() use m_Socket afterwards:
	m_Socket = std::move(a_Link);
}





void cBlockingSslClientSocket::OnDisconnected()
{
	// Mark the socket closed and drop the link before waking any waiter,
	// so that a woken ReceiveEncrypted() / Connect() already sees the disconnected state:
	m_IsConnected = false;
	m_Socket.reset();
	m_Event.Set();
}