diff options
author | LC <mathew1800@gmail.com> | 2020-10-29 06:54:45 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-10-29 06:54:45 +0100 |
commit | 1a6b1bf1d7d54218d933722a0bfc3d4e6762ac0e (patch) | |
tree | 626978551472dce8381730d277f930bc0bc5be07 | |
parent | Merge pull request #4835 from lat9nq/rng-default-time (diff) | |
parent | web_service: follow-up fix to #4842 ... (diff) | |
download | yuzu-1a6b1bf1d7d54218d933722a0bfc3d4e6762ac0e.tar yuzu-1a6b1bf1d7d54218d933722a0bfc3d4e6762ac0e.tar.gz yuzu-1a6b1bf1d7d54218d933722a0bfc3d4e6762ac0e.tar.bz2 yuzu-1a6b1bf1d7d54218d933722a0bfc3d4e6762ac0e.tar.lz yuzu-1a6b1bf1d7d54218d933722a0bfc3d4e6762ac0e.tar.xz yuzu-1a6b1bf1d7d54218d933722a0bfc3d4e6762ac0e.tar.zst yuzu-1a6b1bf1d7d54218d933722a0bfc3d4e6762ac0e.zip |
-rw-r--r-- | externals/httplib/README.md | 2 | ||||
-rw-r--r-- | externals/httplib/httplib.h | 651 | ||||
-rw-r--r-- | src/web_service/web_backend.cpp | 5 |
3 files changed, 441 insertions, 217 deletions
diff --git a/externals/httplib/README.md b/externals/httplib/README.md index 73037d297..1940e446c 100644 --- a/externals/httplib/README.md +++ b/externals/httplib/README.md @@ -1,4 +1,4 @@ -From https://github.com/yhirose/cpp-httplib/tree/fce8e6fefdab4ad48bc5b25c98e5ebfda4f3cf53 +From https://github.com/yhirose/cpp-httplib/tree/ff5677ad197947177c158fe857caff4f0e242045 with https://github.com/yhirose/cpp-httplib/pull/701 MIT License diff --git a/externals/httplib/httplib.h b/externals/httplib/httplib.h index 5139b7f05..8982054e2 100644 --- a/externals/httplib/httplib.h +++ b/externals/httplib/httplib.h @@ -173,6 +173,7 @@ using socket_t = int; #define INVALID_SOCKET (-1) #endif //_WIN32 +#include <algorithm> #include <array> #include <atomic> #include <cassert> @@ -237,6 +238,27 @@ namespace httplib { namespace detail { +/* + * Backport std::make_unique from C++14. + * + * NOTE: This code came up with the following stackoverflow post: + * https://stackoverflow.com/questions/10149840/c-arrays-and-make-unique + * + */ + +template <class T, class... Args> +typename std::enable_if<!std::is_array<T>::value, std::unique_ptr<T>>::type +make_unique(Args &&... args) { + return std::unique_ptr<T>(new T(std::forward<Args>(args)...)); +} + +template <class T> +typename std::enable_if<std::is_array<T>::value, std::unique_ptr<T>>::type +make_unique(std::size_t n) { + typedef typename std::remove_extent<T>::type RT; + return std::unique_ptr<T>(new RT[n]); +} + struct ci { bool operator()(const std::string &s1, const std::string &s2) const { return std::lexicographical_compare( @@ -304,6 +326,10 @@ using ContentProvider = using ContentProviderWithoutLength = std::function<bool(size_t offset, DataSink &sink)>; +using ContentReceiverWithProgress = + std::function<bool(const char *data, size_t data_length, uint64_t offset, + uint64_t total_length)>; + using ContentReceiver = std::function<bool(const char *data, size_t data_length)>; @@ -317,14 +343,17 @@ public: ContentReceiver receiver)>; ContentReader(Reader reader, MultipartReader multipart_reader) - : reader_(reader), multipart_reader_(multipart_reader) {} + : reader_(std::move(reader)), + multipart_reader_(std::move(multipart_reader)) {} bool operator()(MultipartContentHeader header, ContentReceiver receiver) const { - return multipart_reader_(header, receiver); + return multipart_reader_(std::move(header), std::move(receiver)); } - bool operator()(ContentReceiver receiver) const { return reader_(receiver); } + bool operator()(ContentReceiver receiver) const { + return reader_(std::move(receiver)); + } Reader reader_; MultipartReader multipart_reader_; @@ -353,7 +382,7 @@ struct Request { // for client size_t redirect_count = CPPHTTPLIB_REDIRECT_MAX_COUNT; ResponseHandler response_handler; - ContentReceiver content_receiver; + ContentReceiverWithProgress content_receiver; size_t content_length = 0; ContentProvider content_provider; Progress progress; @@ -475,7 +504,7 @@ public: void enqueue(std::function<void()> fn) override { std::unique_lock<std::mutex> lock(mutex_); - jobs_.push_back(fn); + jobs_.push_back(std::move(fn)); cond_.notify_one(); } @@ -664,8 +693,8 @@ private: ContentReceiver multipart_receiver); virtual bool process_and_close_socket(socket_t sock); - - struct MountPointEntry { + + struct MountPointEntry { std::string mount_point; std::string base_dir; Headers headers; @@ -704,23 +733,27 @@ enum Error { Canceled, SSLConnection, SSLLoadingCerts, - SSLServerVerification + SSLServerVerification, + UnsupportedMultipartBoundaryChars }; class Result { public: - Result(const std::shared_ptr<Response> &res, Error err) - : res_(res), err_(err) {} + Result(std::unique_ptr<Response> res, Error err) + : res_(std::move(res)), err_(err) {} operator bool() const { return res_ != nullptr; } bool operator==(std::nullptr_t) const { return res_ == nullptr; } bool operator!=(std::nullptr_t) const { return res_ != nullptr; } const Response &value() const { return *res_; } + Response &value() { return *res_; } const Response &operator*() const { return *res_; } + Response &operator*() { return *res_; } const Response *operator->() const { return res_.get(); } + Response *operator->() { return res_.get(); } Error error() const { return err_; } private: - std::shared_ptr<Response> res_; + std::unique_ptr<Response> res_; Error err_; }; @@ -777,6 +810,8 @@ public: Result Post(const char *path, const MultipartFormDataItems &items); Result Post(const char *path, const Headers &headers, const MultipartFormDataItems &items); + Result Post(const char *path, const Headers &headers, + const MultipartFormDataItems &items, const std::string &boundary); Result Put(const char *path); Result Put(const char *path, const std::string &body, @@ -863,7 +898,21 @@ protected: }; virtual bool create_and_connect_socket(Socket &socket); - virtual void close_socket(Socket &socket, bool process_socket_ret); + + // All of: + // shutdown_ssl + // shutdown_socket + // close_socket + // should ONLY be called when socket_mutex_ is locked. + // Also, shutdown_ssl and close_socket should also NOT be called concurrently + // with a DIFFERENT thread sending requests using that socket. + virtual void shutdown_ssl(Socket &socket, bool shutdown_gracefully); + void shutdown_socket(Socket &socket); + void close_socket(Socket &socket); + + // Similar to shutdown_ssl and close_socket, this should NOT be called + // concurrently with a DIFFERENT thread sending requests from the socket + void lock_socket_and_shutdown_and_close(); bool process_request(Stream &strm, const Request &req, Response &res, bool close_connection); @@ -873,7 +922,7 @@ protected: void copy_settings(const ClientImpl &rhs); // Error state - mutable Error error_ = Error::Success; + mutable std::atomic<Error> error_; // Socket endoint information const std::string host_; @@ -885,6 +934,11 @@ protected: mutable std::mutex socket_mutex_; std::recursive_mutex request_mutex_; + // These are all protected under socket_mutex + int socket_requests_in_flight_ = 0; + std::thread::id socket_requests_are_from_thread_ = std::thread::id(); + bool socket_should_be_closed_when_request_is_done_ = false; + // Default headers Headers default_headers_; @@ -942,13 +996,13 @@ private: bool redirect(const Request &req, Response &res); bool handle_request(Stream &strm, const Request &req, Response &res, bool close_connection); - void stop_core(); - std::shared_ptr<Response> send_with_content_provider( + std::unique_ptr<Response> send_with_content_provider( const char *method, const char *path, const Headers &headers, const std::string &body, size_t content_length, ContentProvider content_provider, const char *content_type); - virtual bool process_socket(Socket &socket, + // socket is const because this function is called when socket_mutex_ is not locked + virtual bool process_socket(const Socket &socket, std::function<bool(Stream &strm)> callback); virtual bool is_ssl() const; }; @@ -1012,6 +1066,8 @@ public: Result Post(const char *path, const MultipartFormDataItems &items); Result Post(const char *path, const Headers &headers, const MultipartFormDataItems &items); + Result Post(const char *path, const Headers &headers, + const MultipartFormDataItems &items, const std::string &boundary); Result Put(const char *path); Result Put(const char *path, const std::string &body, const char *content_type); @@ -1098,7 +1154,7 @@ public: #endif private: - std::shared_ptr<ClientImpl> cli_; + std::unique_ptr<ClientImpl> cli_; #ifdef CPPHTTPLIB_OPENSSL_SUPPORT bool is_ssl_ = false; @@ -1154,9 +1210,9 @@ public: private: bool create_and_connect_socket(Socket &socket) override; - void close_socket(Socket &socket, bool process_socket_ret) override; + void shutdown_ssl(Socket &socket, bool shutdown_gracefully) override; - bool process_socket(Socket &socket, + bool process_socket(const Socket &socket, std::function<bool(Stream &strm)> callback) override; bool is_ssl() const override; @@ -1942,9 +1998,9 @@ inline socket_t create_client_socket(const char *host, int port, bool tcp_nodelay, SocketOptions socket_options, time_t timeout_sec, time_t timeout_usec, - const std::string &intf, Error &error) { + const std::string &intf, std::atomic<Error> &error) { auto sock = create_socket( - host, port, 0, tcp_nodelay, socket_options, + host, port, 0, tcp_nodelay, std::move(socket_options), [&](socket_t sock, struct addrinfo &ai) -> bool { if (!intf.empty()) { #ifdef USE_IF2IP @@ -2478,7 +2534,8 @@ inline bool read_headers(Stream &strm, Headers &headers) { } inline bool read_content_with_length(Stream &strm, uint64_t len, - Progress progress, ContentReceiver out) { + Progress progress, + ContentReceiverWithProgress out) { char buf[CPPHTTPLIB_RECV_BUFSIZ]; uint64_t r = 0; @@ -2487,8 +2544,7 @@ inline bool read_content_with_length(Stream &strm, uint64_t len, auto n = strm.read(buf, (std::min)(read_len, CPPHTTPLIB_RECV_BUFSIZ)); if (n <= 0) { return false; } - if (!out(buf, static_cast<size_t>(n))) { return false; } - + if (!out(buf, static_cast<size_t>(n), r, len)) { return false; } r += static_cast<uint64_t>(n); if (progress) { @@ -2510,8 +2566,10 @@ inline void skip_content_with_length(Stream &strm, uint64_t len) { } } -inline bool read_content_without_length(Stream &strm, ContentReceiver out) { +inline bool read_content_without_length(Stream &strm, + ContentReceiverWithProgress out) { char buf[CPPHTTPLIB_RECV_BUFSIZ]; + uint64_t r = 0; for (;;) { auto n = strm.read(buf, CPPHTTPLIB_RECV_BUFSIZ); if (n < 0) { @@ -2519,13 +2577,16 @@ inline bool read_content_without_length(Stream &strm, ContentReceiver out) { } else if (n == 0) { return true; } - if (!out(buf, static_cast<size_t>(n))) { return false; } + + if (!out(buf, static_cast<size_t>(n), r, 0)) { return false; } + r += static_cast<uint64_t>(n); } return true; } -inline bool read_content_chunked(Stream &strm, ContentReceiver out) { +inline bool read_content_chunked(Stream &strm, + ContentReceiverWithProgress out) { const auto bufsiz = 16; char buf[bufsiz]; @@ -2570,23 +2631,24 @@ inline bool is_chunked_transfer_encoding(const Headers &headers) { } template <typename T, typename U> -bool prepare_content_receiver(T &x, int &status, ContentReceiver receiver, +bool prepare_content_receiver(T &x, int &status, + ContentReceiverWithProgress receiver, bool decompress, U callback) { if (decompress) { std::string encoding = x.get_header_value("Content-Encoding"); - std::shared_ptr<decompressor> decompressor; + std::unique_ptr<decompressor> decompressor; if (encoding.find("gzip") != std::string::npos || encoding.find("deflate") != std::string::npos) { #ifdef CPPHTTPLIB_ZLIB_SUPPORT - decompressor = std::make_shared<gzip_decompressor>(); + decompressor = detail::make_unique<gzip_decompressor>(); #else status = 415; return false; #endif } else if (encoding.find("br") != std::string::npos) { #ifdef CPPHTTPLIB_BROTLI_SUPPORT - decompressor = std::make_shared<brotli_decompressor>(); + decompressor = detail::make_unique<brotli_decompressor>(); #else status = 415; return false; @@ -2595,12 +2657,14 @@ bool prepare_content_receiver(T &x, int &status, ContentReceiver receiver, if (decompressor) { if (decompressor->is_valid()) { - ContentReceiver out = [&](const char *buf, size_t n) { - return decompressor->decompress( - buf, n, - [&](const char *buf, size_t n) { return receiver(buf, n); }); + ContentReceiverWithProgress out = [&](const char *buf, size_t n, + uint64_t off, uint64_t len) { + return decompressor->decompress(buf, n, + [&](const char *buf, size_t n) { + return receiver(buf, n, off, len); + }); }; - return callback(out); + return callback(std::move(out)); } else { status = 500; return false; @@ -2608,18 +2672,20 @@ bool prepare_content_receiver(T &x, int &status, ContentReceiver receiver, } } - ContentReceiver out = [&](const char *buf, size_t n) { - return receiver(buf, n); + ContentReceiverWithProgress out = [&](const char *buf, size_t n, uint64_t off, + uint64_t len) { + return receiver(buf, n, off, len); }; - return callback(out); + return callback(std::move(out)); } template <typename T> bool read_content(Stream &strm, T &x, size_t payload_max_length, int &status, - Progress progress, ContentReceiver receiver, + Progress progress, ContentReceiverWithProgress receiver, bool decompress) { return prepare_content_receiver( - x, status, receiver, decompress, [&](const ContentReceiver &out) { + x, status, std::move(receiver), decompress, + [&](const ContentReceiverWithProgress &out) { auto ret = true; auto exceed_payload_max_length = false; @@ -2634,7 +2700,7 @@ bool read_content(Stream &strm, T &x, size_t payload_max_length, int &status, skip_content_with_length(strm, len); ret = false; } else if (len > 0) { - ret = read_content_with_length(strm, len, progress, out); + ret = read_content_with_length(strm, len, std::move(progress), out); } } @@ -2875,7 +2941,7 @@ inline bool parse_multipart_boundary(const std::string &content_type, return !boundary.empty(); } -inline bool parse_range_header(const std::string &s, Ranges &ranges) { +inline bool parse_range_header(const std::string &s, Ranges &ranges) try { static auto re_first_range = std::regex(R"(bytes=(\d*-\d*(?:,\s*\d*-\d*)*))"); std::smatch m; if (std::regex_match(s, m, re_first_range)) { @@ -2907,7 +2973,7 @@ inline bool parse_range_header(const std::string &s, Ranges &ranges) { return all_valid_ranges; } return false; -} +} catch (...) { return false; } class MultipartFormDataParser { public: @@ -2918,7 +2984,8 @@ public: bool is_valid() const { return is_valid_; } template <typename T, typename U> - bool parse(const char *buf, size_t n, T content_callback, U header_callback) { + bool parse(const char *buf, size_t n, const T &content_callback, + const U &header_callback) { static const std::regex re_content_disposition( "^Content-Disposition:\\s*form-data;\\s*name=\"(.*?)\"(?:;\\s*filename=" @@ -3090,8 +3157,13 @@ inline std::string make_multipart_data_boundary() { static const char data[] = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"; + // std::random_device might actually be deterministic on some + // platforms, but due to lack of support in the c++ standard library, + // doing better requires either some ugly hacks or breaking portability. std::random_device seed_gen; - std::mt19937 engine(seed_gen()); + // Request 128 bits of entropy for initialization + std::seed_seq seed_sequence{seed_gen(), seed_gen(), seed_gen(), seed_gen()}; + std::mt19937 engine(seed_sequence); std::string result = "--cpp-httplib-multipart-data-"; @@ -3114,7 +3186,7 @@ get_range_offset_and_length(const Request &req, size_t content_length, auto slen = static_cast<ssize_t>(content_length); if (r.first == -1) { - r.first = slen - r.second; + r.first = (std::max)(static_cast<ssize_t>(0), slen - r.second); r.second = slen - 1; } @@ -3451,7 +3523,7 @@ inline std::pair<std::string, std::string> make_range_header(Ranges ranges) { if (r.second != -1) { field += std::to_string(r.second); } i++; } - return std::make_pair("Range", field); + return std::make_pair("Range", std::move(field)); } inline std::pair<std::string, std::string> @@ -3460,7 +3532,7 @@ make_basic_authentication_header(const std::string &username, bool is_proxy = false) { auto field = "Basic " + detail::base64_encode(username + ":" + password); auto key = is_proxy ? "Proxy-Authorization" : "Authorization"; - return std::make_pair(key, field); + return std::make_pair(key, std::move(field)); } inline std::pair<std::string, std::string> @@ -3468,7 +3540,7 @@ make_bearer_token_authentication_header(const std::string &token, bool is_proxy = false) { auto field = "Bearer " + token; auto key = is_proxy ? "Proxy-Authorization" : "Authorization"; - return std::make_pair(key, field); + return std::make_pair(key, std::move(field)); } // Request implementation @@ -3761,60 +3833,66 @@ inline Server::Server() inline Server::~Server() {} inline Server &Server::Get(const char *pattern, Handler handler) { - get_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); + get_handlers_.push_back( + std::make_pair(std::regex(pattern), std::move(handler))); return *this; } inline Server &Server::Post(const char *pattern, Handler handler) { - post_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); + post_handlers_.push_back( + std::make_pair(std::regex(pattern), std::move(handler))); return *this; } inline Server &Server::Post(const char *pattern, HandlerWithContentReader handler) { post_handlers_for_content_reader_.push_back( - std::make_pair(std::regex(pattern), handler)); + std::make_pair(std::regex(pattern), std::move(handler))); return *this; } inline Server &Server::Put(const char *pattern, Handler handler) { - put_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); + put_handlers_.push_back( + std::make_pair(std::regex(pattern), std::move(handler))); return *this; } inline Server &Server::Put(const char *pattern, HandlerWithContentReader handler) { put_handlers_for_content_reader_.push_back( - std::make_pair(std::regex(pattern), handler)); + std::make_pair(std::regex(pattern), std::move(handler))); return *this; } inline Server &Server::Patch(const char *pattern, Handler handler) { - patch_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); + patch_handlers_.push_back( + std::make_pair(std::regex(pattern), std::move(handler))); return *this; } inline Server &Server::Patch(const char *pattern, HandlerWithContentReader handler) { patch_handlers_for_content_reader_.push_back( - std::make_pair(std::regex(pattern), handler)); + std::make_pair(std::regex(pattern), std::move(handler))); return *this; } inline Server &Server::Delete(const char *pattern, Handler handler) { - delete_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); + delete_handlers_.push_back( + std::make_pair(std::regex(pattern), std::move(handler))); return *this; } inline Server &Server::Delete(const char *pattern, HandlerWithContentReader handler) { delete_handlers_for_content_reader_.push_back( - std::make_pair(std::regex(pattern), handler)); + std::make_pair(std::regex(pattern), std::move(handler))); return *this; } inline Server &Server::Options(const char *pattern, Handler handler) { - options_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); + options_handlers_.push_back( + std::make_pair(std::regex(pattern), std::move(handler))); return *this; } @@ -3860,7 +3938,7 @@ inline void Server::set_error_handler(Handler handler) { inline void Server::set_tcp_nodelay(bool on) { tcp_nodelay_ = on; } inline void Server::set_socket_options(SocketOptions socket_options) { - socket_options_ = socket_options; + socket_options_ = std::move(socket_options); } inline void Server::set_logger(Logger logger) { logger_ = std::move(logger); } @@ -4045,16 +4123,16 @@ inline bool Server::write_response(Stream &strm, bool close_connection, } if (type != detail::EncodingType::None) { - std::shared_ptr<detail::compressor> compressor; + std::unique_ptr<detail::compressor> compressor; if (type == detail::EncodingType::Gzip) { #ifdef CPPHTTPLIB_ZLIB_SUPPORT - compressor = std::make_shared<detail::gzip_compressor>(); + compressor = detail::make_unique<detail::gzip_compressor>(); res.set_header("Content-Encoding", "gzip"); #endif } else if (type == detail::EncodingType::Brotli) { #ifdef CPPHTTPLIB_BROTLI_SUPPORT - compressor = std::make_shared<detail::brotli_compressor>(); + compressor = detail::make_unique<detail::brotli_compressor>(); res.set_header("Content-Encoding", "brotli"); #endif } @@ -4136,17 +4214,17 @@ Server::write_content_with_provider(Stream &strm, const Request &req, if (res.is_chunked_content_provider) { auto type = detail::encoding_type(req, res); - std::shared_ptr<detail::compressor> compressor; + std::unique_ptr<detail::compressor> compressor; if (type == detail::EncodingType::Gzip) { #ifdef CPPHTTPLIB_ZLIB_SUPPORT - compressor = std::make_shared<detail::gzip_compressor>(); + compressor = detail::make_unique<detail::gzip_compressor>(); #endif } else if (type == detail::EncodingType::Brotli) { #ifdef CPPHTTPLIB_BROTLI_SUPPORT - compressor = std::make_shared<detail::brotli_compressor>(); + compressor = detail::make_unique<detail::brotli_compressor>(); #endif } else { - compressor = std::make_shared<detail::nocompressor>(); + compressor = detail::make_unique<detail::nocompressor>(); } assert(compressor != nullptr); @@ -4198,8 +4276,9 @@ inline bool Server::read_content_with_content_receiver( Stream &strm, Request &req, Response &res, ContentReceiver receiver, MultipartContentHeader multipart_header, ContentReceiver multipart_receiver) { - return read_content_core(strm, req, res, receiver, multipart_header, - multipart_receiver); + return read_content_core(strm, req, res, std::move(receiver), + std::move(multipart_header), + std::move(multipart_receiver)); } inline bool Server::read_content_core(Stream &strm, Request &req, Response &res, @@ -4207,7 +4286,7 @@ inline bool Server::read_content_core(Stream &strm, Request &req, Response &res, MultipartContentHeader mulitpart_header, ContentReceiver multipart_receiver) { detail::MultipartFormDataParser multipart_form_data_parser; - ContentReceiver out; + ContentReceiverWithProgress out; if (req.is_multipart_form_data()) { const auto &content_type = req.get_header_value("Content-Type"); @@ -4218,7 +4297,7 @@ inline bool Server::read_content_core(Stream &strm, Request &req, Response &res, } multipart_form_data_parser.set_boundary(std::move(boundary)); - out = [&](const char *buf, size_t n) { + out = [&](const char *buf, size_t n, uint64_t /*off*/, uint64_t /*len*/) { /* For debug size_t pos = 0; while (pos < n) { @@ -4234,7 +4313,8 @@ inline bool Server::read_content_core(Stream &strm, Request &req, Response &res, mulitpart_header); }; } else { - out = receiver; + out = [receiver](const char *buf, size_t n, uint64_t /*off*/, + uint64_t /*len*/) { return receiver(buf, n); }; } if (req.method == "DELETE" && !req.has_header("Content-Length")) { @@ -4271,7 +4351,7 @@ inline bool Server::handle_file_request(Request &req, Response &res, auto type = detail::find_content_type(path, file_extension_and_mimetype_map_); if (type) { res.set_header("Content-Type", type); } - for (const auto& kv : entry.headers) { + for (const auto &kv : entry.headers) { res.set_header(kv.first.c_str(), kv.second); } res.status = 200; @@ -4290,7 +4370,7 @@ inline socket_t Server::create_server_socket(const char *host, int port, int socket_flags, SocketOptions socket_options) const { return detail::create_socket( - host, port, socket_flags, tcp_nodelay_, socket_options, + host, port, socket_flags, tcp_nodelay_, std::move(socket_options), [](socket_t sock, struct addrinfo &ai) -> bool { if (::bind(sock, ai.ai_addr, static_cast<socklen_t>(ai.ai_addrlen))) { return false; @@ -4392,32 +4472,37 @@ inline bool Server::routing(Request &req, Response &res, Stream &strm) { { ContentReader reader( [&](ContentReceiver receiver) { - return read_content_with_content_receiver(strm, req, res, receiver, - nullptr, nullptr); + return read_content_with_content_receiver( + strm, req, res, std::move(receiver), nullptr, nullptr); }, [&](MultipartContentHeader header, ContentReceiver receiver) { return read_content_with_content_receiver(strm, req, res, nullptr, - header, receiver); + std::move(header), + std::move(receiver)); }); if (req.method == "POST") { if (dispatch_request_for_content_reader( - req, res, reader, post_handlers_for_content_reader_)) { + req, res, std::move(reader), + post_handlers_for_content_reader_)) { return true; } } else if (req.method == "PUT") { if (dispatch_request_for_content_reader( - req, res, reader, put_handlers_for_content_reader_)) { + req, res, std::move(reader), + put_handlers_for_content_reader_)) { return true; } } else if (req.method == "PATCH") { if (dispatch_request_for_content_reader( - req, res, reader, patch_handlers_for_content_reader_)) { + req, res, std::move(reader), + patch_handlers_for_content_reader_)) { return true; } } else if (req.method == "DELETE") { if (dispatch_request_for_content_reader( - req, res, reader, delete_handlers_for_content_reader_)) { + req, res, std::move(reader), + delete_handlers_for_content_reader_)) { return true; } } @@ -4448,7 +4533,6 @@ inline bool Server::routing(Request &req, Response &res, Stream &strm) { inline bool Server::dispatch_request(Request &req, Response &res, const Handlers &handlers) { - try { for (const auto &x : handlers) { const auto &pattern = x.first; @@ -4531,7 +4615,8 @@ Server::process_request(Stream &strm, bool close_connection, if (req.has_header("Range")) { const auto &range_header_value = req.get_header_value("Range"); if (!detail::parse_range_header(range_header_value, req.ranges)) { - // TODO: error + res.status = 416; + return write_response(strm, close_connection, req, res); } } @@ -4588,11 +4673,11 @@ inline ClientImpl::ClientImpl(const std::string &host, int port) inline ClientImpl::ClientImpl(const std::string &host, int port, const std::string &client_cert_path, const std::string &client_key_path) - : host_(host), port_(port), + : error_(Error::Success), host_(host), port_(port), host_and_port_(host_ + ":" + std::to_string(port_)), client_cert_path_(client_cert_path), client_key_path_(client_key_path) {} -inline ClientImpl::~ClientImpl() { stop_core(); } +inline ClientImpl::~ClientImpl() { lock_socket_and_shutdown_and_close(); } inline bool ClientImpl::is_valid() const { return true; } @@ -4653,15 +4738,47 @@ inline bool ClientImpl::create_and_connect_socket(Socket &socket) { return true; } -inline void ClientImpl::close_socket(Socket &socket, - bool /*process_socket_ret*/) { - detail::close_socket(socket.sock); - socket_.sock = INVALID_SOCKET; +inline void ClientImpl::shutdown_ssl(Socket &socket, bool shutdown_gracefully) { + (void)socket; + (void)shutdown_gracefully; + //If there are any requests in flight from threads other than us, then it's + //a thread-unsafe race because individual ssl* objects are not thread-safe. + assert(socket_requests_in_flight_ == 0 || + socket_requests_are_from_thread_ == std::this_thread::get_id()); +} + +inline void ClientImpl::shutdown_socket(Socket &socket) { + if (socket.sock == INVALID_SOCKET) + return; + detail::shutdown_socket(socket.sock); +} + +inline void ClientImpl::close_socket(Socket &socket) { + // If there are requests in flight in another thread, usually closing + // the socket will be fine and they will simply receive an error when + // using the closed socket, but it is still a bug since rarely the OS + // may reassign the socket id to be used for a new socket, and then + // suddenly they will be operating on a live socket that is different + // than the one they intended! + assert(socket_requests_in_flight_ == 0 || + socket_requests_are_from_thread_ == std::this_thread::get_id()); + // It is also a bug if this happens while SSL is still active #ifdef CPPHTTPLIB_OPENSSL_SUPPORT - socket_.ssl = nullptr; + assert(socket.ssl == nullptr); #endif + if (socket.sock == INVALID_SOCKET) + return; + detail::close_socket(socket.sock); + socket.sock = INVALID_SOCKET; } +inline void ClientImpl::lock_socket_and_shutdown_and_close() { + std::lock_guard<std::mutex> guard(socket_mutex_); + shutdown_ssl(socket_, true); + shutdown_socket(socket_); + close_socket(socket_); +} + inline bool ClientImpl::read_response_line(Stream &strm, Response &res) { std::array<char, 2048> buf; @@ -4696,11 +4813,23 @@ inline bool ClientImpl::send(const Request &req, Response &res) { { std::lock_guard<std::mutex> guard(socket_mutex_); + // Set this to false immediately - if it ever gets set to true by the end of the + // request, we know another thread instructed us to close the socket. + socket_should_be_closed_when_request_is_done_ = false; auto is_alive = false; if (socket_.is_open()) { is_alive = detail::select_write(socket_.sock, 0, 0) > 0; - if (!is_alive) { close_socket(socket_, false); } + if (!is_alive) { + // Attempt to avoid sigpipe by shutting down nongracefully if it seems like + // the other side has already closed the connection + // Also, there cannot be any requests in flight from other threads since we locked + // request_mutex_, so safe to close everything immediately + const bool shutdown_gracefully = false; + shutdown_ssl(socket_, shutdown_gracefully); + shutdown_socket(socket_); + close_socket(socket_); + } } if (!is_alive) { @@ -4721,15 +4850,38 @@ inline bool ClientImpl::send(const Request &req, Response &res) { } #endif } + + // Mark the current socket as being in use so that it cannot be closed by anyone + // else while this request is ongoing, even though we will be releasing the mutex. + if (socket_requests_in_flight_ > 1) { + assert(socket_requests_are_from_thread_ == std::this_thread::get_id()); + } + socket_requests_in_flight_ += 1; + socket_requests_are_from_thread_ = std::this_thread::get_id(); } auto close_connection = !keep_alive_; - auto ret = process_socket(socket_, [&](Stream &strm) { return handle_request(strm, req, res, close_connection); }); - if (close_connection || !ret) { stop_core(); } + //Briefly lock mutex in order to mark that a request is no longer ongoing + { + std::lock_guard<std::mutex> guard(socket_mutex_); + socket_requests_in_flight_ -= 1; + if (socket_requests_in_flight_ <= 0) { + assert(socket_requests_in_flight_ == 0); + socket_requests_are_from_thread_ = std::thread::id(); + } + + if (socket_should_be_closed_when_request_is_done_ || + close_connection || + !ret ) { + shutdown_ssl(socket_, true); + shutdown_socket(socket_); + close_socket(socket_); + } + } if (!ret) { if (error_ == Error::Success) { error_ = Error::Unknown; } @@ -4973,7 +5125,7 @@ inline bool ClientImpl::write_request(Stream &strm, const Request &req, return true; } -inline std::shared_ptr<Response> ClientImpl::send_with_content_provider( +inline std::unique_ptr<Response> ClientImpl::send_with_content_provider( const char *method, const char *path, const Headers &headers, const std::string &body, size_t content_length, ContentProvider content_provider, const char *content_type) { @@ -5036,15 +5188,15 @@ inline std::shared_ptr<Response> ClientImpl::send_with_content_provider( { if (content_provider) { req.content_length = content_length; - req.content_provider = content_provider; + req.content_provider = std::move(content_provider); } else { req.body = body; } } - auto res = std::make_shared<Response>(); + auto res = detail::make_unique<Response>(); - return send(req, *res) ? res : nullptr; + return send(req, *res) ? std::move(res) : nullptr; } inline bool ClientImpl::process_request(Stream &strm, const Request &req, @@ -5070,16 +5222,21 @@ inline bool ClientImpl::process_request(Stream &strm, const Request &req, if (req.method != "HEAD" && req.method != "CONNECT") { auto out = req.content_receiver - ? static_cast<ContentReceiver>([&](const char *buf, size_t n) { - auto ret = req.content_receiver(buf, n); - if (!ret) { error_ = Error::Canceled; } - return ret; - }) - : static_cast<ContentReceiver>([&](const char *buf, size_t n) { - if (res.body.size() + n > res.body.max_size()) { return false; } - res.body.append(buf, n); - return true; - }); + ? static_cast<ContentReceiverWithProgress>( + [&](const char *buf, size_t n, uint64_t off, uint64_t len) { + auto ret = req.content_receiver(buf, n, off, len); + if (!ret) { error_ = Error::Canceled; } + return ret; + }) + : static_cast<ContentReceiverWithProgress>( + [&](const char *buf, size_t n, uint64_t /*off*/, + uint64_t /*len*/) { + if (res.body.size() + n > res.body.max_size()) { + return false; + } + res.body.append(buf, n); + return true; + }); auto progress = [&](uint64_t current, uint64_t total) { if (!req.progress) { return true; } @@ -5090,7 +5247,8 @@ inline bool ClientImpl::process_request(Stream &strm, const Request &req, int dummy_status; if (!detail::read_content(strm, res, (std::numeric_limits<size_t>::max)(), - dummy_status, progress, out, decompress_)) { + dummy_status, std::move(progress), std::move(out), + decompress_)) { if (error_ != Error::Canceled) { error_ = Error::Read; } return false; } @@ -5098,7 +5256,16 @@ inline bool ClientImpl::process_request(Stream &strm, const Request &req, if (res.get_header_value("Connection") == "close" || (res.version == "HTTP/1.0" && res.reason != "Connection established")) { - stop_core(); + // TODO this requires a not-entirely-obvious chain of calls to be correct + // for this to be safe. Maybe a code refactor (such as moving this out to + // the send function and getting rid of the recursiveness of the mutex) + // could make this more obvious. + + // This is safe to call because process_request is only called by handle_request + // which is only called by send, which locks the request mutex during the process. + // It would be a bug to call it from a different thread since it's a thread-safety + // issue to do these things to the socket if another thread is using the socket. + lock_socket_and_shutdown_and_close(); } // Log @@ -5108,11 +5275,11 @@ inline bool ClientImpl::process_request(Stream &strm, const Request &req, } inline bool -ClientImpl::process_socket(Socket &socket, +ClientImpl::process_socket(const Socket &socket, std::function<bool(Stream &strm)> callback) { - return detail::process_client_socket(socket.sock, read_timeout_sec_, - read_timeout_usec_, write_timeout_sec_, - write_timeout_usec_, callback); + return detail::process_client_socket( + socket.sock, read_timeout_sec_, read_timeout_usec_, write_timeout_sec_, + write_timeout_usec_, std::move(callback)); } inline bool ClientImpl::is_ssl() const { return false; } @@ -5138,9 +5305,9 @@ inline Result ClientImpl::Get(const char *path, const Headers &headers, req.headers.insert(headers.begin(), headers.end()); req.progress = std::move(progress); - auto res = std::make_shared<Response>(); + auto res = detail::make_unique<Response>(); auto ret = send(req, *res); - return Result{ret ? res : nullptr, get_last_error()}; + return Result{ret ? std::move(res) : nullptr, get_last_error()}; } inline Result ClientImpl::Get(const char *path, @@ -5170,23 +5337,23 @@ inline Result ClientImpl::Get(const char *path, const Headers &headers, inline Result ClientImpl::Get(const char *path, ResponseHandler response_handler, ContentReceiver content_receiver) { - return Get(path, Headers(), std::move(response_handler), content_receiver, - nullptr); + return Get(path, Headers(), std::move(response_handler), + std::move(content_receiver), nullptr); } inline Result ClientImpl::Get(const char *path, const Headers &headers, ResponseHandler response_handler, ContentReceiver content_receiver) { - return Get(path, headers, std::move(response_handler), content_receiver, - nullptr); + return Get(path, headers, std::move(response_handler), + std::move(content_receiver), nullptr); } inline Result ClientImpl::Get(const char *path, ResponseHandler response_handler, ContentReceiver content_receiver, Progress progress) { - return Get(path, Headers(), std::move(response_handler), content_receiver, - progress); + return Get(path, Headers(), std::move(response_handler), + std::move(content_receiver), std::move(progress)); } inline Result ClientImpl::Get(const char *path, const Headers &headers, @@ -5199,12 +5366,16 @@ inline Result ClientImpl::Get(const char *path, const Headers &headers, req.headers = default_headers_; req.headers.insert(headers.begin(), headers.end()); req.response_handler = std::move(response_handler); - req.content_receiver = std::move(content_receiver); + req.content_receiver = + [content_receiver](const char *data, size_t data_length, + uint64_t /*offset*/, uint64_t /*total_length*/) { + return content_receiver(data, data_length); + }; req.progress = std::move(progress); - auto res = std::make_shared<Response>(); + auto res = detail::make_unique<Response>(); auto ret = send(req, *res); - return Result{ret ? res : nullptr, get_last_error()}; + return Result{ret ? std::move(res) : nullptr, get_last_error()}; } inline Result ClientImpl::Head(const char *path) { @@ -5218,9 +5389,9 @@ inline Result ClientImpl::Head(const char *path, const Headers &headers) { req.headers.insert(headers.begin(), headers.end()); req.path = path; - auto res = std::make_shared<Response>(); + auto res = detail::make_unique<Response>(); auto ret = send(req, *res); - return Result{ret ? res : nullptr, get_last_error()}; + return Result{ret ? std::move(res) : nullptr, get_last_error()}; } inline Result ClientImpl::Post(const char *path) { @@ -5237,7 +5408,7 @@ inline Result ClientImpl::Post(const char *path, const Headers &headers, const char *content_type) { auto ret = send_with_content_provider("POST", path, headers, body, 0, nullptr, content_type); - return Result{ret, get_last_error()}; + return Result{std::move(ret), get_last_error()}; } inline Result ClientImpl::Post(const char *path, const Params ¶ms) { @@ -5247,17 +5418,18 @@ inline Result ClientImpl::Post(const char *path, const Params ¶ms) { inline Result ClientImpl::Post(const char *path, size_t content_length, ContentProvider content_provider, const char *content_type) { - return Post(path, Headers(), content_length, content_provider, content_type); + return Post(path, Headers(), content_length, std::move(content_provider), + content_type); } inline Result ClientImpl::Post(const char *path, const Headers &headers, size_t content_length, ContentProvider content_provider, const char *content_type) { - auto ret = send_with_content_provider("POST", path, headers, std::string(), - content_length, content_provider, - content_type); - return Result{ret, get_last_error()}; + auto ret = send_with_content_provider( + "POST", path, headers, std::string(), content_length, + std::move(content_provider), content_type); + return Result{std::move(ret), get_last_error()}; } inline Result ClientImpl::Post(const char *path, const Headers &headers, @@ -5273,7 +5445,18 @@ inline Result ClientImpl::Post(const char *path, inline Result ClientImpl::Post(const char *path, const Headers &headers, const MultipartFormDataItems &items) { - auto boundary = detail::make_multipart_data_boundary(); + return Post(path, headers, items, detail::make_multipart_data_boundary()); +} +inline Result ClientImpl::Post(const char *path, const Headers &headers, + const MultipartFormDataItems &items, + const std::string &boundary) { + for (size_t i = 0; i < boundary.size(); i++) { + char c = boundary[i]; + if (!std::isalnum(c) && c != '-' && c != '_') { + error_ = Error::UnsupportedMultipartBoundaryChars; + return Result{nullptr, error_}; + } + } std::string body; @@ -5311,23 +5494,24 @@ inline Result ClientImpl::Put(const char *path, const Headers &headers, const char *content_type) { auto ret = send_with_content_provider("PUT", path, headers, body, 0, nullptr, content_type); - return Result{ret, get_last_error()}; + return Result{std::move(ret), get_last_error()}; } inline Result ClientImpl::Put(const char *path, size_t content_length, ContentProvider content_provider, const char *content_type) { - return Put(path, Headers(), content_length, content_provider, content_type); + return Put(path, Headers(), content_length, std::move(content_provider), + content_type); } inline Result ClientImpl::Put(const char *path, const Headers &headers, size_t content_length, ContentProvider content_provider, const char *content_type) { - auto ret = send_with_content_provider("PUT", path, headers, std::string(), - content_length, content_provider, - content_type); - return Result{ret, get_last_error()}; + auto ret = send_with_content_provider( + "PUT", path, headers, std::string(), content_length, + std::move(content_provider), content_type); + return Result{std::move(ret), get_last_error()}; } inline Result ClientImpl::Put(const char *path, const Params ¶ms) { @@ -5350,23 +5534,24 @@ inline Result ClientImpl::Patch(const char *path, const Headers &headers, const char *content_type) { auto ret = send_with_content_provider("PATCH", path, headers, body, 0, nullptr, content_type); - return Result{ret, get_last_error()}; + return Result{std::move(ret), get_last_error()}; } inline Result ClientImpl::Patch(const char *path, size_t content_length, ContentProvider content_provider, const char *content_type) { - return Patch(path, Headers(), content_length, content_provider, content_type); + return Patch(path, Headers(), content_length, std::move(content_provider), + content_type); } inline Result ClientImpl::Patch(const char *path, const Headers &headers, size_t content_length, ContentProvider content_provider, const char *content_type) { - auto ret = send_with_content_provider("PATCH", path, headers, std::string(), - content_length, content_provider, - content_type); - return Result{ret, get_last_error()}; + auto ret = send_with_content_provider( + "PATCH", path, headers, std::string(), content_length, + std::move(content_provider), content_type); + return Result{std::move(ret), get_last_error()}; } inline Result ClientImpl::Delete(const char *path) { @@ -5394,9 +5579,9 @@ inline Result ClientImpl::Delete(const char *path, const Headers &headers, if (content_type) { req.headers.emplace("Content-Type", content_type); } req.body = body; - auto res = std::make_shared<Response>(); + auto res = detail::make_unique<Response>(); auto ret = send(req, *res); - return Result{ret ? res : nullptr, get_last_error()}; + return Result{ret ? std::move(res) : nullptr, get_last_error()}; } inline Result ClientImpl::Options(const char *path) { @@ -5410,9 +5595,9 @@ inline Result ClientImpl::Options(const char *path, const Headers &headers) { req.headers.insert(headers.begin(), headers.end()); req.path = path; - auto res = std::make_shared<Response>(); + auto res = detail::make_unique<Response>(); auto ret = send(req, *res); - return Result{ret ? res : nullptr, get_last_error()}; + return Result{ret ? std::move(res) : nullptr, get_last_error()}; } inline size_t ClientImpl::is_socket_open() const { @@ -5421,18 +5606,27 @@ inline size_t ClientImpl::is_socket_open() const { } inline void ClientImpl::stop() { - stop_core(); - error_ = Error::Canceled; -} - -inline void ClientImpl::stop_core() { std::lock_guard<std::mutex> guard(socket_mutex_); - if (socket_.is_open()) { - detail::shutdown_socket(socket_.sock); - std::this_thread::sleep_for(std::chrono::milliseconds(1)); - close_socket(socket_, true); - std::this_thread::sleep_for(std::chrono::milliseconds(1)); + // There is no guarantee that this doesn't get overwritten later, but set it so that + // there is a good chance that any threads stopping as a result pick up this error. + error_ = Error::Canceled; + + // If there is anything ongoing right now, the ONLY thread-safe thing we can do + // is to shutdown_socket, so that threads using this socket suddenly discover + // they can't read/write any more and error out. + // Everything else (closing the socket, shutting ssl down) is unsafe because these + // actions are not thread-safe. + if (socket_requests_in_flight_ > 0) { + shutdown_socket(socket_); + // Aside from that, we set a flag for the socket to be closed when we're done. + socket_should_be_closed_when_request_is_done_ = true; + return; } + + //Otherwise, sitll holding the mutex, we can shut everything down ourselves + shutdown_ssl(socket_, true); + shutdown_socket(socket_); + close_socket(socket_); } inline void ClientImpl::set_connection_timeout(time_t sec, time_t usec) { @@ -5479,7 +5673,7 @@ inline void ClientImpl::set_default_headers(Headers headers) { inline void ClientImpl::set_tcp_nodelay(bool on) { tcp_nodelay_ = on; } inline void ClientImpl::set_socket_options(SocketOptions socket_options) { - socket_options_ = socket_options; + socket_options_ = std::move(socket_options); } inline void ClientImpl::set_compress(bool on) { compress_ = on; } @@ -5554,9 +5748,12 @@ inline SSL *ssl_new(socket_t sock, SSL_CTX *ctx, std::mutex &ctx_mutex, } inline void ssl_delete(std::mutex &ctx_mutex, SSL *ssl, - bool process_socket_ret) { - if (process_socket_ret) { - SSL_shutdown(ssl); // shutdown only if not already closed by remote + bool shutdown_gracefully) { + // sometimes we may want to skip this to try to avoid SIGPIPE if we know + // the remote has closed the network connection + // Note that it is not always possible to avoid SIGPIPE, this is merely a best-efforts. + if (shutdown_gracefully) { + SSL_shutdown(ssl); } std::lock_guard<std::mutex> guard(ctx_mutex); @@ -5650,22 +5847,7 @@ inline SSLSocketStream::SSLSocketStream(socket_t sock, SSL *ssl, read_timeout_usec_(read_timeout_usec), write_timeout_sec_(write_timeout_sec), write_timeout_usec_(write_timeout_usec) { - { - timeval tv; - tv.tv_sec = static_cast<long>(read_timeout_sec); - tv.tv_usec = static_cast<decltype(tv.tv_usec)>(read_timeout_usec); - - setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, reinterpret_cast<char *>(&tv), - sizeof(tv)); - } - { - timeval tv; - tv.tv_sec = static_cast<long>(write_timeout_sec); - tv.tv_usec = static_cast<decltype(tv.tv_usec)>(write_timeout_usec); - - setsockopt(sock, SOL_SOCKET, SO_SNDTIMEO, reinterpret_cast<char *>(&tv), - sizeof(tv)); - } + SSL_clear_mode(ssl, SSL_MODE_AUTO_RETRY); } inline SSLSocketStream::~SSLSocketStream() {} @@ -5680,8 +5862,27 @@ inline bool SSLSocketStream::is_writable() const { } inline ssize_t SSLSocketStream::read(char *ptr, size_t size) { - if (SSL_pending(ssl_) > 0 || is_readable()) { + if (SSL_pending(ssl_) > 0) { return SSL_read(ssl_, ptr, static_cast<int>(size)); + } else if (is_readable()) { + auto ret = SSL_read(ssl_, ptr, static_cast<int>(size)); + if (ret < 0) { + auto err = SSL_get_error(ssl_, ret); + while (err == SSL_ERROR_WANT_READ) { + if (SSL_pending(ssl_) > 0) { + return SSL_read(ssl_, ptr, static_cast<int>(size)); + } else if (is_readable()) { + ret = SSL_read(ssl_, ptr, static_cast<int>(size)); + if (ret >= 0) { + return ret; + } + err = SSL_get_error(ssl_, ret); + } else { + return -1; + } + } + } + return ret; } return -1; } @@ -5788,9 +5989,12 @@ inline bool SSLServer::process_and_close_socket(socket_t sock) { }); detail::ssl_delete(ctx_mutex_, ssl, ret); + detail::shutdown_socket(sock); + detail::close_socket(sock); return ret; } + detail::shutdown_socket(sock); detail::close_socket(sock); return false; } @@ -5843,6 +6047,10 @@ inline SSLClient::SSLClient(const std::string &host, int port, inline SSLClient::~SSLClient() { if (ctx_) { SSL_CTX_free(ctx_); } + // Make sure to shut down SSL since shutdown_ssl will resolve to the + // base function rather than the derived function once we get to the + // base class destructor, and won't free the SSL (causing a leak). + SSLClient::shutdown_ssl(socket_, true); } inline bool SSLClient::is_valid() const { return ctx_; } @@ -5876,11 +6084,11 @@ inline bool SSLClient::create_and_connect_socket(Socket &socket) { return is_valid() && ClientImpl::create_and_connect_socket(socket); } +// Assumes that socket_mutex_ is locked and that there are no requests in flight inline bool SSLClient::connect_with_proxy(Socket &socket, Response &res, bool &success) { success = true; Response res2; - if (!detail::process_client_socket( socket.sock, read_timeout_sec_, read_timeout_usec_, write_timeout_sec_, write_timeout_usec_, [&](Stream &strm) { @@ -5889,7 +6097,10 @@ inline bool SSLClient::connect_with_proxy(Socket &socket, Response &res, req2.path = host_and_port_; return process_request(strm, req2, res2, false); })) { - close_socket(socket, true); + // Thread-safe to close everything because we are assuming there are no requests in flight + shutdown_ssl(socket, true); + shutdown_socket(socket); + close_socket(socket); success = false; return false; } @@ -5912,7 +6123,10 @@ inline bool SSLClient::connect_with_proxy(Socket &socket, Response &res, true)); return process_request(strm, req3, res3, false); })) { - close_socket(socket, true); + // Thread-safe to close everything because we are assuming there are no requests in flight + shutdown_ssl(socket, true); + shutdown_socket(socket); + close_socket(socket); success = false; return false; } @@ -6005,26 +6219,30 @@ inline bool SSLClient::initialize_ssl(Socket &socket) { return true; } - close_socket(socket, false); + shutdown_socket(socket); + close_socket(socket); return false; } -inline void SSLClient::close_socket(Socket &socket, bool process_socket_ret) { - detail::close_socket(socket.sock); - socket_.sock = INVALID_SOCKET; +inline void SSLClient::shutdown_ssl(Socket &socket, bool shutdown_gracefully) { + if (socket.sock == INVALID_SOCKET) { + assert(socket.ssl == nullptr); + return; + } if (socket.ssl) { - detail::ssl_delete(ctx_mutex_, socket.ssl, process_socket_ret); - socket_.ssl = nullptr; + detail::ssl_delete(ctx_mutex_, socket.ssl, shutdown_gracefully); + socket.ssl = nullptr; } + assert(socket.ssl == nullptr); } inline bool -SSLClient::process_socket(Socket &socket, +SSLClient::process_socket(const Socket &socket, std::function<bool(Stream &strm)> callback) { assert(socket.ssl); return detail::process_client_socket_ssl( socket.ssl, socket.sock, read_timeout_sec_, read_timeout_usec_, - write_timeout_sec_, write_timeout_usec_, callback); + write_timeout_sec_, write_timeout_usec_, std::move(callback)); } inline bool SSLClient::is_ssl() const { return true; } @@ -6175,6 +6393,8 @@ inline Client::Client(const char *scheme_host_port, #else if (!scheme.empty() && scheme != "http") { #endif + std::string msg = "'" + scheme + "' scheme is not supported."; + throw std::invalid_argument(msg); return; } @@ -6187,28 +6407,28 @@ inline Client::Client(const char *scheme_host_port, if (is_ssl) { #ifdef CPPHTTPLIB_OPENSSL_SUPPORT - cli_ = std::make_shared<SSLClient>(host.c_str(), port, client_cert_path, - client_key_path); + cli_ = detail::make_unique<SSLClient>(host.c_str(), port, + client_cert_path, client_key_path); is_ssl_ = is_ssl; #endif } else { - cli_ = std::make_shared<ClientImpl>(host.c_str(), port, client_cert_path, - client_key_path); + cli_ = detail::make_unique<ClientImpl>(host.c_str(), port, + client_cert_path, client_key_path); } } else { - cli_ = std::make_shared<ClientImpl>(scheme_host_port, 80, client_cert_path, - client_key_path); + cli_ = detail::make_unique<ClientImpl>(scheme_host_port, 80, + client_cert_path, client_key_path); } } inline Client::Client(const std::string &host, int port) - : cli_(std::make_shared<ClientImpl>(host, port)) {} + : cli_(detail::make_unique<ClientImpl>(host, port)) {} inline Client::Client(const std::string &host, int port, const std::string &client_cert_path, const std::string &client_key_path) - : cli_(std::make_shared<ClientImpl>(host, port, client_cert_path, - client_key_path)) {} + : cli_(detail::make_unique<ClientImpl>(host, port, client_cert_path, + client_key_path)) {} inline Client::~Client() {} @@ -6221,11 +6441,11 @@ inline Result Client::Get(const char *path, const Headers &headers) { return cli_->Get(path, headers); } inline Result Client::Get(const char *path, Progress progress) { - return cli_->Get(path, progress); + return cli_->Get(path, std::move(progress)); } inline Result Client::Get(const char *path, const Headers &headers, Progress progress) { - return cli_->Get(path, headers, progress); + return cli_->Get(path, headers, std::move(progress)); } inline Result Client::Get(const char *path, ContentReceiver content_receiver) { return cli_->Get(path, std::move(content_receiver)); @@ -6262,7 +6482,8 @@ inline Result Client::Get(const char *path, ResponseHandler response_handler, inline Result Client::Get(const char *path, const Headers &headers, ResponseHandler response_handler, ContentReceiver content_receiver, Progress progress) { - return cli_->Get(path, headers, response_handler, content_receiver, progress); + return cli_->Get(path, headers, std::move(response_handler), + std::move(content_receiver), std::move(progress)); } inline Result Client::Head(const char *path) { return cli_->Head(path); } @@ -6282,13 +6503,14 @@ inline Result Client::Post(const char *path, const Headers &headers, inline Result Client::Post(const char *path, size_t content_length, ContentProvider content_provider, const char *content_type) { - return cli_->Post(path, content_length, content_provider, content_type); + return cli_->Post(path, content_length, std::move(content_provider), + content_type); } inline Result Client::Post(const char *path, const Headers &headers, size_t content_length, ContentProvider content_provider, const char *content_type) { - return cli_->Post(path, headers, content_length, content_provider, + return cli_->Post(path, headers, content_length, std::move(content_provider), content_type); } inline Result Client::Post(const char *path, const Params ¶ms) { @@ -6306,6 +6528,11 @@ inline Result Client::Post(const char *path, const Headers &headers, const MultipartFormDataItems &items) { return cli_->Post(path, headers, items); } +inline Result Client::Post(const char *path, const Headers &headers, + const MultipartFormDataItems &items, + const std::string &boundary) { + return cli_->Post(path, headers, items, boundary); +} inline Result Client::Put(const char *path) { return cli_->Put(path); } inline Result Client::Put(const char *path, const std::string &body, const char *content_type) { @@ -6318,13 +6545,14 @@ inline Result Client::Put(const char *path, const Headers &headers, inline Result Client::Put(const char *path, size_t content_length, ContentProvider content_provider, const char *content_type) { - return cli_->Put(path, content_length, content_provider, content_type); + return cli_->Put(path, content_length, std::move(content_provider), + content_type); } inline Result Client::Put(const char *path, const Headers &headers, size_t content_length, ContentProvider content_provider, const char *content_type) { - return cli_->Put(path, headers, content_length, content_provider, + return cli_->Put(path, headers, content_length, std::move(content_provider), content_type); } inline Result Client::Put(const char *path, const Params ¶ms) { @@ -6345,13 +6573,14 @@ inline Result Client::Patch(const char *path, const Headers &headers, inline Result Client::Patch(const char *path, size_t content_length, ContentProvider content_provider, const char *content_type) { - return cli_->Patch(path, content_length, content_provider, content_type); + return cli_->Patch(path, content_length, std::move(content_provider), + content_type); } inline Result Client::Patch(const char *path, const Headers &headers, size_t content_length, ContentProvider content_provider, const char *content_type) { - return cli_->Patch(path, headers, content_length, content_provider, + return cli_->Patch(path, headers, content_length, std::move(content_provider), content_type); } inline Result Client::Delete(const char *path) { return cli_->Delete(path); } @@ -6386,7 +6615,7 @@ inline void Client::set_default_headers(Headers headers) { inline void Client::set_tcp_nodelay(bool on) { cli_->set_tcp_nodelay(on); } inline void Client::set_socket_options(SocketOptions socket_options) { - cli_->set_socket_options(socket_options); + cli_->set_socket_options(std::move(socket_options)); } inline void Client::set_connection_timeout(time_t sec, time_t usec) { diff --git a/src/web_service/web_backend.cpp b/src/web_service/web_backend.cpp index f264b98a0..67183e64c 100644 --- a/src/web_service/web_backend.cpp +++ b/src/web_service/web_backend.cpp @@ -71,11 +71,6 @@ struct Client::Impl { return {}; } - if (!cli->is_socket_open()) { - LOG_ERROR(WebService, "Failed to open socket, skipping request!"); - return {}; - } - cli->set_connection_timeout(TIMEOUT_SECONDS); cli->set_read_timeout(TIMEOUT_SECONDS); cli->set_write_timeout(TIMEOUT_SECONDS); |