diff options
author | Nicolas Werner <nicolas.werner@hotmail.de> | 2020-05-09 23:31:00 +0200 |
---|---|---|
committer | Nicolas Werner <nicolas.werner@hotmail.de> | 2020-05-09 23:33:03 +0200 |
commit | 7b1fa60cc6c74a53de0af636fa8f4f06caf87fa0 (patch) | |
tree | 30a2b6247c7b7be8c03995d0474f87df7019bb29 | |
parent | 813790e6031d0fef07be73c489ebaf59d3d1af4b (diff) |
Add SSO
closes #94
-rw-r--r-- | CMakeLists.txt | 4 | ||||
-rw-r--r-- | src/ChatPage.cpp | 8 | ||||
-rw-r--r-- | src/LoginPage.cpp | 113 | ||||
-rw-r--r-- | src/LoginPage.h | 11 | ||||
-rw-r--r-- | src/SSOHandler.cpp | 54 | ||||
-rw-r--r-- | src/SSOHandler.h | 24 | ||||
-rw-r--r-- | third_party/cpp-httplib-0.5.12/httplib.h | 5125 |
7 files changed, 5311 insertions, 28 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt index 10a49dce..97cb8ea2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -294,6 +294,7 @@ set(SRC_FILES src/RegisterPage.cpp src/RoomInfoListItem.cpp src/RoomList.cpp + src/SSOHandler.cpp src/SideBarActions.cpp src/Splitter.cpp src/TextInputWidget.cpp @@ -493,6 +494,7 @@ qt5_wrap_cpp(MOC_HEADERS src/RegisterPage.h src/RoomInfoListItem.h src/RoomList.h + src/SSOHandler.h src/SideBarActions.h src/Splitter.h src/TextInputWidget.h @@ -556,7 +558,7 @@ elseif(WIN32) else() target_link_libraries (nheko PRIVATE Qt5::DBus) endif() -target_include_directories(nheko PRIVATE src includes third_party/blurhash) +target_include_directories(nheko PRIVATE src includes third_party/blurhash third_party/cpp-httplib-0.5.12) target_link_libraries(nheko PRIVATE MatrixClient::MatrixClient diff --git a/src/ChatPage.cpp b/src/ChatPage.cpp index ae3c7a11..7c4aac77 100644 --- a/src/ChatPage.cpp +++ b/src/ChatPage.cpp @@ -988,8 +988,12 @@ ChatPage::trySync() const auto err_code = mtx::errors::to_string(err->matrix_error.errcode); const int status_code = static_cast<int>(err->status_code); - if (http::is_logged_in() && err->matrix_error.errcode == - mtx::errors::ErrorCode::M_UNKNOWN_TOKEN) { + if ((http::is_logged_in() && + (err->matrix_error.errcode == + mtx::errors::ErrorCode::M_UNKNOWN_TOKEN || + err->matrix_error.errcode == + mtx::errors::ErrorCode::M_MISSING_TOKEN)) || + !http::is_logged_in()) { emit dropToLoginPageCb(msg); return; } diff --git a/src/LoginPage.cpp b/src/LoginPage.cpp index 6d96419a..4c3999ec 100644 --- a/src/LoginPage.cpp +++ b/src/LoginPage.cpp @@ -15,28 +15,35 @@ * along with this program. If not, see <http://www.gnu.org/licenses/>. */ +#include <QDesktopServices> #include <QPainter> #include <QStyleOption> #include <mtx/identifiers.hpp> +#include <mtx/requests.hpp> #include <mtx/responses/login.hpp> #include "Config.h" #include "Logging.h" #include "LoginPage.h" #include "MatrixClient.h" +#include "SSOHandler.h" #include "ui/FlatButton.h" #include "ui/LoadingIndicator.h" #include "ui/OverlayModal.h" #include "ui/RaisedButton.h" #include "ui/TextField.h" +Q_DECLARE_METATYPE(LoginPage::LoginMethod) + using namespace mtx::identifiers; LoginPage::LoginPage(QWidget *parent) : QWidget(parent) , inferredServerAddress_() { + qRegisterMetaType<LoginPage::LoginMethod>("LoginPage::LoginMethod"); + top_layout_ = new QVBoxLayout(); top_bar_layout_ = new QHBoxLayout(); @@ -226,7 +233,8 @@ LoginPage::onMatrixIdEntered() emit versionErrorCb(tr("Autodiscovery failed. Unknown error when " "requesting .well-known.")); nhlog::net()->error("Autodiscovery failed. Unknown error when " - "requesting .well-known."); + "requesting .well-known. {}", + err->status_code); return; } @@ -263,7 +271,16 @@ LoginPage::checkHomeserverVersion() return; } - emit versionOkCb(); + http::client()->get_login( + [this](mtx::responses::LoginFlows flows, mtx::http::RequestErr err) { + if (err || flows.flows.empty()) + emit versionOkCb(LoginMethod::Password); + + if (flows.flows[0].type == mtx::user_interactive::auth_types::sso) + emit versionOkCb(LoginMethod::SSO); + else + emit versionOkCb(LoginMethod::Password); + }); }); } @@ -294,12 +311,22 @@ LoginPage::versionError(const QString &error) } void -LoginPage::versionOk() +LoginPage::versionOk(LoginMethod loginMethod) { + this->loginMethod = loginMethod; + serverLayout_->removeWidget(spinner_); matrixidLayout_->removeWidget(spinner_); spinner_->stop(); + if (loginMethod == LoginMethod::SSO) { + password_input_->hide(); + login_button_->setText(tr("SSO LOGIN")); + } else { + password_input_->show(); + login_button_->setText(tr("LOGIN")); + } + if (serverInput_->isVisible()) serverInput_->hide(); } @@ -317,29 +344,68 @@ LoginPage::onLoginButtonClicked() return loginError("You have entered an invalid Matrix ID e.g @joe:matrix.org"); } - if (password_input_->text().isEmpty()) - return loginError(tr("Empty password")); + if (loginMethod == LoginMethod::Password) { + if (password_input_->text().isEmpty()) + return loginError(tr("Empty password")); + + http::client()->login( + user.localpart(), + password_input_->text().toStdString(), + deviceName_->text().trimmed().isEmpty() ? initialDeviceName() + : deviceName_->text().toStdString(), + [this](const mtx::responses::Login &res, mtx::http::RequestErr err) { + if (err) { + emit loginError(QString::fromStdString(err->matrix_error.error)); + emit errorOccurred(); + return; + } - http::client()->login( - user.localpart(), - password_input_->text().toStdString(), - deviceName_->text().trimmed().isEmpty() ? initialDeviceName() - : deviceName_->text().toStdString(), - [this](const mtx::responses::Login &res, mtx::http::RequestErr err) { - if (err) { - emit loginError(QString::fromStdString(err->matrix_error.error)); - emit errorOccurred(); - return; - } + if (res.well_known) { + http::client()->set_server(res.well_known->homeserver.base_url); + nhlog::net()->info("Login requested to user server: " + + res.well_known->homeserver.base_url); + } - if (res.well_known) { - http::client()->set_server(res.well_known->homeserver.base_url); - nhlog::net()->info("Login requested to user server: " + - res.well_known->homeserver.base_url); - } + emit loginOk(res); + }); + } else { + auto sso = new SSOHandler(); + connect(sso, &SSOHandler::ssoSuccess, this, [this, sso](std::string token) { + mtx::requests::Login req{}; + req.token = token; + req.type = mtx::user_interactive::auth_types::token; + req.device_id = deviceName_->text().trimmed().isEmpty() + ? initialDeviceName() + : deviceName_->text().toStdString(); + http::client()->login( + req, [this](const mtx::responses::Login &res, mtx::http::RequestErr err) { + if (err) { + emit loginError( + QString::fromStdString(err->matrix_error.error)); + emit errorOccurred(); + return; + } + + if (res.well_known) { + http::client()->set_server( + res.well_known->homeserver.base_url); + nhlog::net()->info("Login requested to user server: " + + res.well_known->homeserver.base_url); + } + + emit loginOk(res); + }); + sso->deleteLater(); + }); + connect(sso, &SSOHandler::ssoFailed, this, [this, sso]() { + emit loginError(tr("SSO login failed")); + emit errorOccurred(); + sso->deleteLater(); + }); - emit loginOk(res); - }); + QDesktopServices::openUrl( + QString::fromStdString(http::client()->login_sso_redirect(sso->url()))); + } emit loggingIn(); } @@ -349,6 +415,7 @@ LoginPage::reset() { matrixid_input_->clear(); password_input_->clear(); + password_input_->show(); serverInput_->clear(); spinner_->stop(); diff --git a/src/LoginPage.h b/src/LoginPage.h index 4b84abfc..8a402aea 100644 --- a/src/LoginPage.h +++ b/src/LoginPage.h @@ -38,6 +38,12 @@ class LoginPage : public QWidget Q_OBJECT public: + enum class LoginMethod + { + Password, + SSO, + }; + LoginPage(QWidget *parent = nullptr); void reset(); @@ -50,7 +56,7 @@ signals: //! Used to trigger the corresponding slot outside of the main thread. void versionErrorCb(const QString &err); void loginErrorCb(const QString &err); - void versionOkCb(); + void versionOkCb(LoginPage::LoginMethod method); void loginOk(const mtx::responses::Login &res); @@ -77,7 +83,7 @@ private slots: // Callback for errors produced during server probing void versionError(const QString &error_message); // Callback for successful server probing - void versionOk(); + void versionOk(LoginPage::LoginMethod method); private: bool isMatrixIdValid(); @@ -123,4 +129,5 @@ private: TextField *password_input_; TextField *deviceName_; TextField *serverInput_; + LoginMethod loginMethod = LoginMethod::Password; }; diff --git a/src/SSOHandler.cpp b/src/SSOHandler.cpp new file mode 100644 index 00000000..0ee2fc17 --- /dev/null +++ b/src/SSOHandler.cpp @@ -0,0 +1,54 @@ +#include "SSOHandler.h" + +#include <QTimer> + +#include <thread> + +#include "Logging.h" + +SSOHandler::SSOHandler(QObject *) +{ + QTimer::singleShot(120000, this, &SSOHandler::ssoFailed); + + using namespace httplib; + + svr.set_logger([](const Request &req, const Response &res) { + nhlog::net()->info("req: {}, res: {}", req.path, res.status); + }); + + svr.Get("/sso", [this](const Request &req, Response &res) { + if (req.has_param("loginToken")) { + auto val = req.get_param_value("loginToken"); + res.set_content("SSO success", "text/plain"); + emit ssoSuccess(val); + } else { + res.set_content("Missing loginToken for SSO login!", "text/plain"); + emit ssoFailed(); + } + }); + + std::thread t([this]() { + this->port = svr.bind_to_any_port("localhost"); + svr.listen_after_bind(); + + }); + t.detach(); + + while (!svr.is_running()) { + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } +} + +SSOHandler::~SSOHandler() +{ + svr.stop(); + while (svr.is_running()) { + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } +} + +std::string +SSOHandler::url() const +{ + return "http://localhost:" + std::to_string(port) + "/sso"; +} diff --git a/src/SSOHandler.h b/src/SSOHandler.h new file mode 100644 index 00000000..325b7a58 --- /dev/null +++ b/src/SSOHandler.h @@ -0,0 +1,24 @@ +#include "httplib.h" + +#include <QObject> +#include <string> + +class SSOHandler : public QObject +{ + Q_OBJECT + +public: + SSOHandler(QObject *parent = nullptr); + + ~SSOHandler(); + + std::string url() const; + +signals: + void ssoSuccess(std::string token); + void ssoFailed(); + +private: + httplib::Server svr; + int port = 0; +}; diff --git a/third_party/cpp-httplib-0.5.12/httplib.h b/third_party/cpp-httplib-0.5.12/httplib.h new file mode 100644 index 00000000..7816df8b --- /dev/null +++ b/third_party/cpp-httplib-0.5.12/httplib.h @@ -0,0 +1,5125 @@ +// +// httplib.h +// +// Copyright (c) 2020 Yuji Hirose. All rights reserved. +// MIT License +// + +#ifndef CPPHTTPLIB_HTTPLIB_H +#define CPPHTTPLIB_HTTPLIB_H + +/* + * Configuration + */ + +#ifndef CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND +#define CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND 5 +#endif + +#ifndef CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND +#define CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND 0 +#endif + +#ifndef CPPHTTPLIB_KEEPALIVE_MAX_COUNT +#define CPPHTTPLIB_KEEPALIVE_MAX_COUNT 5 +#endif + +#ifndef CPPHTTPLIB_READ_TIMEOUT_SECOND +#define CPPHTTPLIB_READ_TIMEOUT_SECOND 5 +#endif + +#ifndef CPPHTTPLIB_READ_TIMEOUT_USECOND +#define CPPHTTPLIB_READ_TIMEOUT_USECOND 0 +#endif + +#ifndef CPPHTTPLIB_REQUEST_URI_MAX_LENGTH +#define CPPHTTPLIB_REQUEST_URI_MAX_LENGTH 8192 +#endif + +#ifndef CPPHTTPLIB_REDIRECT_MAX_COUNT +#define CPPHTTPLIB_REDIRECT_MAX_COUNT 20 +#endif + +#ifndef CPPHTTPLIB_PAYLOAD_MAX_LENGTH +#define CPPHTTPLIB_PAYLOAD_MAX_LENGTH ((std::numeric_limits<size_t>::max)()) +#endif + +#ifndef CPPHTTPLIB_RECV_BUFSIZ +#define CPPHTTPLIB_RECV_BUFSIZ size_t(4096u) +#endif + +#ifndef CPPHTTPLIB_THREAD_POOL_COUNT +#define CPPHTTPLIB_THREAD_POOL_COUNT \ + ((std::max)(8u, std::thread::hardware_concurrency() - 1)) +#endif + +/* + * Headers + */ + +#ifdef _WIN32 +#ifndef _CRT_SECURE_NO_WARNINGS +#define _CRT_SECURE_NO_WARNINGS +#endif //_CRT_SECURE_NO_WARNINGS + +#ifndef _CRT_NONSTDC_NO_DEPRECATE +#define _CRT_NONSTDC_NO_DEPRECATE +#endif //_CRT_NONSTDC_NO_DEPRECATE + +#if defined(_MSC_VER) +#ifdef _WIN64 +using ssize_t = __int64; +#else +using ssize_t = int; +#endif + +#if _MSC_VER < 1900 +#define snprintf _snprintf_s +#endif +#endif // _MSC_VER + +#ifndef S_ISREG +#define S_ISREG(m) (((m)&S_IFREG) == S_IFREG) +#endif // S_ISREG + +#ifndef S_ISDIR +#define S_ISDIR(m) (((m)&S_IFDIR) == S_IFDIR) +#endif // S_ISDIR + +#ifndef NOMINMAX +#define NOMINMAX +#endif // NOMINMAX + +#include <io.h> +#include <winsock2.h> +#include <ws2tcpip.h> + +#ifndef WSA_FLAG_NO_HANDLE_INHERIT +#define WSA_FLAG_NO_HANDLE_INHERIT 0x80 +#endif + +#ifdef _MSC_VER +#pragma comment(lib, "ws2_32.lib") +#endif + +#ifndef strcasecmp +#define strcasecmp _stricmp +#endif // strcasecmp + +using socket_t = SOCKET; +#ifdef CPPHTTPLIB_USE_POLL +#define poll(fds, nfds, timeout) WSAPoll(fds, nfds, timeout) +#endif + +#else // not _WIN32 + +#include <arpa/inet.h> +#include <cstring> +#include <ifaddrs.h> +#include <netdb.h> +#include <netinet/in.h> +#ifdef CPPHTTPLIB_USE_POLL +#include <poll.h> +#endif +#include <csignal> +#include <pthread.h> +#include <sys/select.h> +#include <sys/socket.h> +#include <unistd.h> + +using socket_t = int; +#define INVALID_SOCKET (-1) +#endif //_WIN32 + +#include <array> +#include <atomic> +#include <cassert> +#include <climits> +#include <condition_variable> +#include <errno.h> +#include <fcntl.h> +#include <fstream> +#include <functional> +#include <list> +#include <map> +#include <memory> +#include <mutex> +#include <random> +#include <regex> +#include <string> +#include <sys/stat.h> +#include <thread> + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +#include <openssl/err.h> +#include <openssl/md5.h> +#include <openssl/ssl.h> +#include <openssl/x509v3.h> + +#include <iomanip> +#include <iostream> +#include <sstream> + +// #if OPENSSL_VERSION_NUMBER < 0x1010100fL +// #error Sorry, OpenSSL versions prior to 1.1.1 are not supported +// #endif + +#if OPENSSL_VERSION_NUMBER < 0x10100000L +#include <openssl/crypto.h> +inline const unsigned char *ASN1_STRING_get0_data(const ASN1_STRING *asn1) { + return M_ASN1_STRING_data(asn1); +} +#endif +#endif + +#ifdef CPPHTTPLIB_ZLIB_SUPPORT +#include <zlib.h> +#endif +/* + * Declaration + */ +namespace httplib { + +namespace detail { + +struct ci { + bool operator()(const std::string &s1, const std::string &s2) const { + return std::lexicographical_compare( + s1.begin(), s1.end(), s2.begin(), s2.end(), + [](char c1, char c2) { return ::tolower(c1) < ::tolower(c2); }); + } +}; + +} // namespace detail + +using Headers = std::multimap<std::string, std::string, detail::ci>; + +using Params = std::multimap<std::string, std::string>; +using Match = std::smatch; + +using Progress = std::function<bool(uint64_t current, uint64_t total)>; + +struct Response; +using ResponseHandler = std::function<bool(const Response &response)>; + +struct MultipartFormData { + std::string name; + std::string content; + std::string filename; + std::string content_type; +}; +using MultipartFormDataItems = std::vector<MultipartFormData>; +using MultipartFormDataMap = std::multimap<std::string, MultipartFormData>; + +class DataSink { +public: + DataSink() = default; + DataSink(const DataSink &) = delete; + DataSink &operator=(const DataSink &) = delete; + DataSink(DataSink &&) = delete; + DataSink &operator=(DataSink &&) = delete; + + std::function<void(const char *data, size_t data_len)> write; + std::function<void()> done; + std::function<bool()> is_writable; +}; + +using ContentProvider = + std::function<void(size_t offset, size_t length, DataSink &sink)>; + +using ContentReceiver = + std::function<bool(const char *data, size_t data_length)>; + +using MultipartContentHeader = + std::function<bool(const MultipartFormData &file)>; + +class ContentReader { +public: + using Reader = std::function<bool(ContentReceiver receiver)>; + using MultipartReader = std::function<bool(MultipartContentHeader header, + ContentReceiver receiver)>; + + ContentReader(Reader reader, MultipartReader muitlpart_reader) + : reader_(reader), muitlpart_reader_(muitlpart_reader) {} + + bool operator()(MultipartContentHeader header, + ContentReceiver receiver) const { + return muitlpart_reader_(header, receiver); + } + + bool operator()(ContentReceiver receiver) const { return reader_(receiver); } + + Reader reader_; + MultipartReader muitlpart_reader_; +}; + +using Range = std::pair<ssize_t, ssize_t>; +using Ranges = std::vector<Range>; + +struct Request { + std::string method; + std::string path; + Headers headers; + std::string body; + + std::string remote_addr; + int remote_port = -1; + + // for server + std::string version; + std::string target; + Params params; + MultipartFormDataMap files; + Ranges ranges; + Match matches; + + // for client + size_t redirect_count = CPPHTTPLIB_REDIRECT_MAX_COUNT; + ResponseHandler response_handler; + ContentReceiver content_receiver; + Progress progress; + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + const SSL *ssl; +#endif + + bool has_header(const char *key) const; + std::string get_header_value(const char *key, size_t id = 0) const; + size_t get_header_value_count(const char *key) const; + void set_header(const char *key, const char *val); + void set_header(const char *key, const std::string &val); + + bool has_param(const char *key) const; + std::string get_param_value(const char *key, size_t id = 0) const; + size_t get_param_value_count(const char *key) const; + + bool is_multipart_form_data() const; + + bool has_file(const char *key) const; + MultipartFormData get_file_value(const char *key) const; + + // private members... + size_t content_length; + ContentProvider content_provider; +}; + +struct Response { + std::string version; + int status = -1; + Headers headers; + std::string body; + + bool has_header(const char *key) const; + std::string get_header_value(const char *key, size_t id = 0) const; + size_t get_header_value_count(const char *key) const; + void set_header(const char *key, const char *val); + void set_header(const char *key, const std::string &val); + + void set_redirect(const char *url, int status = 302); + void set_content(const char *s, size_t n, const char *content_type); + void set_content(std::string s, const char *content_type); + + void set_content_provider( + size_t length, + std::function<void(size_t offset, size_t length, DataSink &sink)> + provider, + std::function<void()> resource_releaser = [] {}); + + void set_chunked_content_provider( + std::function<void(size_t offset, DataSink &sink)> provider, + std::function<void()> resource_releaser = [] {}); + + Response() = default; + Response(const Response &) = default; + Response &operator=(const Response &) = default; + Response(Response &&) = default; + Response &operator=(Response &&) = default; + ~Response() { + if (content_provider_resource_releaser) { + content_provider_resource_releaser(); + } + } + + // private members... + size_t content_length = 0; + ContentProvider content_provider; + std::function<void()> content_provider_resource_releaser; +}; + +class Stream { +public: + virtual ~Stream() = default; + + virtual bool is_readable() const = 0; + virtual bool is_writable() const = 0; + + virtual ssize_t read(char *ptr, size_t size) = 0; + virtual ssize_t write(const char *ptr, size_t size) = 0; + virtual void get_remote_ip_and_port(std::string &ip, int &port) const = 0; + + template <typename... Args> + ssize_t write_format(const char *fmt, const Args &... args); + ssize_t write(const char *ptr); + ssize_t write(const std::string &s); +}; + +class TaskQueue { +public: + TaskQueue() = default; + virtual ~TaskQueue() = default; + + virtual void enqueue(std::function<void()> fn) = 0; + virtual void shutdown() = 0; + + virtual void on_idle(){}; +}; + +class ThreadPool : public TaskQueue { +public: + explicit ThreadPool(size_t n) : shutdown_(false) { + while (n) { + threads_.emplace_back(worker(*this)); + n--; + } + } + + ThreadPool(const ThreadPool &) = delete; + ~ThreadPool() override = default; + + void enqueue(std::function<void()> fn) override { + std::unique_lock<std::mutex> lock(mutex_); + jobs_.push_back(fn); + cond_.notify_one(); + } + + void shutdown() override { + // Stop all worker threads... + { + std::unique_lock<std::mutex> lock(mutex_); + shutdown_ = true; + } + + cond_.notify_all(); + + // Join... + for (auto &t : threads_) { + t.join(); + } + } + +private: + struct worker { + explicit worker(ThreadPool &pool) : pool_(pool) {} + + void operator()() { + for (;;) { + std::function<void()> fn; + { + std::unique_lock<std::mutex> lock(pool_.mutex_); + + pool_.cond_.wait( + lock, [&] { return !pool_.jobs_.empty() || pool_.shutdown_; }); + + if (pool_.shutdown_ && pool_.jobs_.empty()) { break; } + + fn = pool_.jobs_.front(); + pool_.jobs_.pop_front(); + } + + assert(true == static_cast<bool>(fn)); + fn(); + } + } + + ThreadPool &pool_; + }; + friend struct worker; + + std::vector<std::thread> threads_; + std::list<std::function<void()>> jobs_; + + bool shutdown_; + + std::condition_variable cond_; + std::mutex mutex_; +}; + +using Logger = std::function<void(const Request &, const Response &)>; + +class Server { +public: + using Handler = std::function<void(const Request &, Response &)>; + using HandlerWithContentReader = std::function<void( + const Request &, Response &, const ContentReader &content_reader)>; + using Expect100ContinueHandler = + std::function<int(const Request &, Response &)>; + + Server(); + + virtual ~Server(); + + virtual bool is_valid() const; + + Server &Get(const char *pattern, Handler handler); + Server &Post(const char *pattern, Handler handler); + Server &Post(const char *pattern, HandlerWithContentReader handler); + Server &Put(const char *pattern, Handler handler); + Server &Put(const char *pattern, HandlerWithContentReader handler); + Server &Patch(const char *pattern, Handler handler); + Server &Patch(const char *pattern, HandlerWithContentReader handler); + Server &Delete(const char *pattern, Handler handler); + Server &Delete(const char *pattern, HandlerWithContentReader handler); + Server &Options(const char *pattern, Handler handler); + + [[deprecated]] bool set_base_dir(const char *dir, + const char *mount_point = nullptr); + bool set_mount_point(const char *mount_point, const char *dir); + bool remove_mount_point(const char *mount_point); + void set_file_extension_and_mimetype_mapping(const char *ext, + const char *mime); + void set_file_request_handler(Handler handler); + + void set_error_handler(Handler handler); + void set_logger(Logger logger); + + void set_expect_100_continue_handler(Expect100ContinueHandler handler); + + void set_keep_alive_max_count(size_t count); + void set_read_timeout(time_t sec, time_t usec); + void set_payload_max_length(size_t length); + + bool bind_to_port(const char *host, int port, int socket_flags = 0); + int bind_to_any_port(const char *host, int socket_flags = 0); + bool listen_after_bind(); + + bool listen(const char *host, int port, int socket_flags = 0); + + bool is_running() const; + void stop(); + + std::function<TaskQueue *(void)> new_task_queue; + +protected: + bool process_request(Stream &strm, bool last_connection, + bool &connection_close, + const std::function<void(Request &)> &setup_request); + + size_t keep_alive_max_count_; + time_t read_timeout_sec_; + time_t read_timeout_usec_; + size_t payload_max_length_; + +private: + using Handlers = std::vector<std::pair<std::regex, Handler>>; + using HandlersForContentReader = + std::vector<std::pair<std::regex, HandlerWithContentReader>>; + + socket_t create_server_socket(const char *host, int port, + int socket_flags) const; + int bind_internal(const char *host, int port, int socket_flags); + bool listen_internal(); + + bool routing(Request &req, Response &res, Stream &strm); + bool handle_file_request(Request &req, Response &res, bool head = false); + bool dispatch_request(Request &req, Response &res, Handlers &handlers); + bool dispatch_request_for_content_reader(Request &req, Response &res, + ContentReader content_reader, + HandlersForContentReader &handlers); + + bool parse_request_line(const char *s, Request &req); + bool write_response(Stream &strm, bool last_connection, const Request &req, + Response &res); + bool write_content_with_provider(Stream &strm, const Request &req, + Response &res, const std::string &boundary, + const std::string &content_type); + bool read_content(Stream &strm, Request &req, Response &res); + bool + read_content_with_content_receiver(Stream &strm, Request &req, Response &res, + ContentReceiver receiver, + MultipartContentHeader multipart_header, + ContentReceiver multipart_receiver); + bool read_content_core(Stream &strm, Request &req, Response &res, + ContentReceiver receiver, + MultipartContentHeader mulitpart_header, + ContentReceiver multipart_receiver); + + virtual bool process_and_close_socket(socket_t sock); + + std::atomic<bool> is_running_; + std::atomic<socket_t> svr_sock_; + std::vector<std::pair<std::string, std::string>> base_dirs_; + std::map<std::string, std::string> file_extension_and_mimetype_map_; + Handler file_request_handler_; + Handlers get_handlers_; + Handlers post_handlers_; + HandlersForContentReader post_handlers_for_content_reader_; + Handlers put_handlers_; + HandlersForContentReader put_handlers_for_content_reader_; + Handlers patch_handlers_; + HandlersForContentReader patch_handlers_for_content_reader_; + Handlers delete_handlers_; + HandlersForContentReader delete_handlers_for_content_reader_; + Handlers options_handlers_; + Handler error_handler_; + Logger logger_; + Expect100ContinueHandler expect_100_continue_handler_; +}; + +class Client { +public: + explicit Client(const std::string &host, int port = 80, + const std::string &client_cert_path = std::string(), + const std::string &client_key_path = std::string()); + + virtual ~Client(); + + virtual bool is_valid() const; + + std::shared_ptr<Response> Get(const char *path); + + std::shared_ptr<Response> Get(const char *path, const Headers &headers); + + std::shared_ptr<Response> Get(const char *path, Progress progress); + + std::shared_ptr<Response> Get(const char *path, const Headers &headers, + Progress progress); + + std::shared_ptr<Response> Get(const char *path, + ContentReceiver content_receiver); + + std::shared_ptr<Response> Get(const char *path, const Headers &headers, + ContentReceiver content_receiver); + + std::shared_ptr<Response> + Get(const char *path, ContentReceiver content_receiver, Progress progress); + + std::shared_ptr<Response> Get(const char *path, const Headers &headers, + ContentReceiver content_receiver, + Progress progress); + + std::shared_ptr<Response> Get(const char *path, const Header |