summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorNicolas Werner <nicolas.werner@hotmail.de>2020-05-09 23:31:00 +0200
committerNicolas Werner <nicolas.werner@hotmail.de>2020-05-09 23:33:03 +0200
commit7b1fa60cc6c74a53de0af636fa8f4f06caf87fa0 (patch)
tree30a2b6247c7b7be8c03995d0474f87df7019bb29
parent813790e6031d0fef07be73c489ebaf59d3d1af4b (diff)
Add SSO
closes #94
-rw-r--r--CMakeLists.txt4
-rw-r--r--src/ChatPage.cpp8
-rw-r--r--src/LoginPage.cpp113
-rw-r--r--src/LoginPage.h11
-rw-r--r--src/SSOHandler.cpp54
-rw-r--r--src/SSOHandler.h24
-rw-r--r--third_party/cpp-httplib-0.5.12/httplib.h5125
7 files changed, 5311 insertions, 28 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 10a49dce..97cb8ea2 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -294,6 +294,7 @@ set(SRC_FILES
src/RegisterPage.cpp
src/RoomInfoListItem.cpp
src/RoomList.cpp
+ src/SSOHandler.cpp
src/SideBarActions.cpp
src/Splitter.cpp
src/TextInputWidget.cpp
@@ -493,6 +494,7 @@ qt5_wrap_cpp(MOC_HEADERS
src/RegisterPage.h
src/RoomInfoListItem.h
src/RoomList.h
+ src/SSOHandler.h
src/SideBarActions.h
src/Splitter.h
src/TextInputWidget.h
@@ -556,7 +558,7 @@ elseif(WIN32)
else()
target_link_libraries (nheko PRIVATE Qt5::DBus)
endif()
-target_include_directories(nheko PRIVATE src includes third_party/blurhash)
+target_include_directories(nheko PRIVATE src includes third_party/blurhash third_party/cpp-httplib-0.5.12)
target_link_libraries(nheko PRIVATE
MatrixClient::MatrixClient
diff --git a/src/ChatPage.cpp b/src/ChatPage.cpp
index ae3c7a11..7c4aac77 100644
--- a/src/ChatPage.cpp
+++ b/src/ChatPage.cpp
@@ -988,8 +988,12 @@ ChatPage::trySync()
const auto err_code = mtx::errors::to_string(err->matrix_error.errcode);
const int status_code = static_cast<int>(err->status_code);
- if (http::is_logged_in() && err->matrix_error.errcode ==
- mtx::errors::ErrorCode::M_UNKNOWN_TOKEN) {
+ if ((http::is_logged_in() &&
+ (err->matrix_error.errcode ==
+ mtx::errors::ErrorCode::M_UNKNOWN_TOKEN ||
+ err->matrix_error.errcode ==
+ mtx::errors::ErrorCode::M_MISSING_TOKEN)) ||
+ !http::is_logged_in()) {
emit dropToLoginPageCb(msg);
return;
}
diff --git a/src/LoginPage.cpp b/src/LoginPage.cpp
index 6d96419a..4c3999ec 100644
--- a/src/LoginPage.cpp
+++ b/src/LoginPage.cpp
@@ -15,28 +15,35 @@
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
+#include <QDesktopServices>
#include <QPainter>
#include <QStyleOption>
#include <mtx/identifiers.hpp>
+#include <mtx/requests.hpp>
#include <mtx/responses/login.hpp>
#include "Config.h"
#include "Logging.h"
#include "LoginPage.h"
#include "MatrixClient.h"
+#include "SSOHandler.h"
#include "ui/FlatButton.h"
#include "ui/LoadingIndicator.h"
#include "ui/OverlayModal.h"
#include "ui/RaisedButton.h"
#include "ui/TextField.h"
+Q_DECLARE_METATYPE(LoginPage::LoginMethod)
+
using namespace mtx::identifiers;
LoginPage::LoginPage(QWidget *parent)
: QWidget(parent)
, inferredServerAddress_()
{
+ qRegisterMetaType<LoginPage::LoginMethod>("LoginPage::LoginMethod");
+
top_layout_ = new QVBoxLayout();
top_bar_layout_ = new QHBoxLayout();
@@ -226,7 +233,8 @@ LoginPage::onMatrixIdEntered()
emit versionErrorCb(tr("Autodiscovery failed. Unknown error when "
"requesting .well-known."));
nhlog::net()->error("Autodiscovery failed. Unknown error when "
- "requesting .well-known.");
+ "requesting .well-known. {}",
+ err->status_code);
return;
}
@@ -263,7 +271,16 @@ LoginPage::checkHomeserverVersion()
return;
}
- emit versionOkCb();
+ http::client()->get_login(
+ [this](mtx::responses::LoginFlows flows, mtx::http::RequestErr err) {
+ if (err || flows.flows.empty())
+ emit versionOkCb(LoginMethod::Password);
+
+ if (flows.flows[0].type == mtx::user_interactive::auth_types::sso)
+ emit versionOkCb(LoginMethod::SSO);
+ else
+ emit versionOkCb(LoginMethod::Password);
+ });
});
}
@@ -294,12 +311,22 @@ LoginPage::versionError(const QString &error)
}
void
-LoginPage::versionOk()
+LoginPage::versionOk(LoginMethod loginMethod)
{
+ this->loginMethod = loginMethod;
+
serverLayout_->removeWidget(spinner_);
matrixidLayout_->removeWidget(spinner_);
spinner_->stop();
+ if (loginMethod == LoginMethod::SSO) {
+ password_input_->hide();
+ login_button_->setText(tr("SSO LOGIN"));
+ } else {
+ password_input_->show();
+ login_button_->setText(tr("LOGIN"));
+ }
+
if (serverInput_->isVisible())
serverInput_->hide();
}
@@ -317,29 +344,68 @@ LoginPage::onLoginButtonClicked()
return loginError("You have entered an invalid Matrix ID e.g @joe:matrix.org");
}
- if (password_input_->text().isEmpty())
- return loginError(tr("Empty password"));
+ if (loginMethod == LoginMethod::Password) {
+ if (password_input_->text().isEmpty())
+ return loginError(tr("Empty password"));
+
+ http::client()->login(
+ user.localpart(),
+ password_input_->text().toStdString(),
+ deviceName_->text().trimmed().isEmpty() ? initialDeviceName()
+ : deviceName_->text().toStdString(),
+ [this](const mtx::responses::Login &res, mtx::http::RequestErr err) {
+ if (err) {
+ emit loginError(QString::fromStdString(err->matrix_error.error));
+ emit errorOccurred();
+ return;
+ }
- http::client()->login(
- user.localpart(),
- password_input_->text().toStdString(),
- deviceName_->text().trimmed().isEmpty() ? initialDeviceName()
- : deviceName_->text().toStdString(),
- [this](const mtx::responses::Login &res, mtx::http::RequestErr err) {
- if (err) {
- emit loginError(QString::fromStdString(err->matrix_error.error));
- emit errorOccurred();
- return;
- }
+ if (res.well_known) {
+ http::client()->set_server(res.well_known->homeserver.base_url);
+ nhlog::net()->info("Login requested to user server: " +
+ res.well_known->homeserver.base_url);
+ }
- if (res.well_known) {
- http::client()->set_server(res.well_known->homeserver.base_url);
- nhlog::net()->info("Login requested to user server: " +
- res.well_known->homeserver.base_url);
- }
+ emit loginOk(res);
+ });
+ } else {
+ auto sso = new SSOHandler();
+ connect(sso, &SSOHandler::ssoSuccess, this, [this, sso](std::string token) {
+ mtx::requests::Login req{};
+ req.token = token;
+ req.type = mtx::user_interactive::auth_types::token;
+ req.device_id = deviceName_->text().trimmed().isEmpty()
+ ? initialDeviceName()
+ : deviceName_->text().toStdString();
+ http::client()->login(
+ req, [this](const mtx::responses::Login &res, mtx::http::RequestErr err) {
+ if (err) {
+ emit loginError(
+ QString::fromStdString(err->matrix_error.error));
+ emit errorOccurred();
+ return;
+ }
+
+ if (res.well_known) {
+ http::client()->set_server(
+ res.well_known->homeserver.base_url);
+ nhlog::net()->info("Login requested to user server: " +
+ res.well_known->homeserver.base_url);
+ }
+
+ emit loginOk(res);
+ });
+ sso->deleteLater();
+ });
+ connect(sso, &SSOHandler::ssoFailed, this, [this, sso]() {
+ emit loginError(tr("SSO login failed"));
+ emit errorOccurred();
+ sso->deleteLater();
+ });
- emit loginOk(res);
- });
+ QDesktopServices::openUrl(
+ QString::fromStdString(http::client()->login_sso_redirect(sso->url())));
+ }
emit loggingIn();
}
@@ -349,6 +415,7 @@ LoginPage::reset()
{
matrixid_input_->clear();
password_input_->clear();
+ password_input_->show();
serverInput_->clear();
spinner_->stop();
diff --git a/src/LoginPage.h b/src/LoginPage.h
index 4b84abfc..8a402aea 100644
--- a/src/LoginPage.h
+++ b/src/LoginPage.h
@@ -38,6 +38,12 @@ class LoginPage : public QWidget
Q_OBJECT
public:
+ enum class LoginMethod
+ {
+ Password,
+ SSO,
+ };
+
LoginPage(QWidget *parent = nullptr);
void reset();
@@ -50,7 +56,7 @@ signals:
//! Used to trigger the corresponding slot outside of the main thread.
void versionErrorCb(const QString &err);
void loginErrorCb(const QString &err);
- void versionOkCb();
+ void versionOkCb(LoginPage::LoginMethod method);
void loginOk(const mtx::responses::Login &res);
@@ -77,7 +83,7 @@ private slots:
// Callback for errors produced during server probing
void versionError(const QString &error_message);
// Callback for successful server probing
- void versionOk();
+ void versionOk(LoginPage::LoginMethod method);
private:
bool isMatrixIdValid();
@@ -123,4 +129,5 @@ private:
TextField *password_input_;
TextField *deviceName_;
TextField *serverInput_;
+ LoginMethod loginMethod = LoginMethod::Password;
};
diff --git a/src/SSOHandler.cpp b/src/SSOHandler.cpp
new file mode 100644
index 00000000..0ee2fc17
--- /dev/null
+++ b/src/SSOHandler.cpp
@@ -0,0 +1,54 @@
+#include "SSOHandler.h"
+
+#include <QTimer>
+
+#include <thread>
+
+#include "Logging.h"
+
+SSOHandler::SSOHandler(QObject *)
+{
+ QTimer::singleShot(120000, this, &SSOHandler::ssoFailed);
+
+ using namespace httplib;
+
+ svr.set_logger([](const Request &req, const Response &res) {
+ nhlog::net()->info("req: {}, res: {}", req.path, res.status);
+ });
+
+ svr.Get("/sso", [this](const Request &req, Response &res) {
+ if (req.has_param("loginToken")) {
+ auto val = req.get_param_value("loginToken");
+ res.set_content("SSO success", "text/plain");
+ emit ssoSuccess(val);
+ } else {
+ res.set_content("Missing loginToken for SSO login!", "text/plain");
+ emit ssoFailed();
+ }
+ });
+
+ std::thread t([this]() {
+ this->port = svr.bind_to_any_port("localhost");
+ svr.listen_after_bind();
+
+ });
+ t.detach();
+
+ while (!svr.is_running()) {
+ std::this_thread::sleep_for(std::chrono::milliseconds(1));
+ }
+}
+
+SSOHandler::~SSOHandler()
+{
+ svr.stop();
+ while (svr.is_running()) {
+ std::this_thread::sleep_for(std::chrono::milliseconds(1));
+ }
+}
+
+std::string
+SSOHandler::url() const
+{
+ return "http://localhost:" + std::to_string(port) + "/sso";
+}
diff --git a/src/SSOHandler.h b/src/SSOHandler.h
new file mode 100644
index 00000000..325b7a58
--- /dev/null
+++ b/src/SSOHandler.h
@@ -0,0 +1,24 @@
+#include "httplib.h"
+
+#include <QObject>
+#include <string>
+
+class SSOHandler : public QObject
+{
+ Q_OBJECT
+
+public:
+ SSOHandler(QObject *parent = nullptr);
+
+ ~SSOHandler();
+
+ std::string url() const;
+
+signals:
+ void ssoSuccess(std::string token);
+ void ssoFailed();
+
+private:
+ httplib::Server svr;
+ int port = 0;
+};
diff --git a/third_party/cpp-httplib-0.5.12/httplib.h b/third_party/cpp-httplib-0.5.12/httplib.h
new file mode 100644
index 00000000..7816df8b
--- /dev/null
+++ b/third_party/cpp-httplib-0.5.12/httplib.h
@@ -0,0 +1,5125 @@
+//
+// httplib.h
+//
+// Copyright (c) 2020 Yuji Hirose. All rights reserved.
+// MIT License
+//
+
+#ifndef CPPHTTPLIB_HTTPLIB_H
+#define CPPHTTPLIB_HTTPLIB_H
+
+/*
+ * Configuration
+ */
+
+#ifndef CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND
+#define CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND 5
+#endif
+
+#ifndef CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND
+#define CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND 0
+#endif
+
+#ifndef CPPHTTPLIB_KEEPALIVE_MAX_COUNT
+#define CPPHTTPLIB_KEEPALIVE_MAX_COUNT 5
+#endif
+
+#ifndef CPPHTTPLIB_READ_TIMEOUT_SECOND
+#define CPPHTTPLIB_READ_TIMEOUT_SECOND 5
+#endif
+
+#ifndef CPPHTTPLIB_READ_TIMEOUT_USECOND
+#define CPPHTTPLIB_READ_TIMEOUT_USECOND 0
+#endif
+
+#ifndef CPPHTTPLIB_REQUEST_URI_MAX_LENGTH
+#define CPPHTTPLIB_REQUEST_URI_MAX_LENGTH 8192
+#endif
+
+#ifndef CPPHTTPLIB_REDIRECT_MAX_COUNT
+#define CPPHTTPLIB_REDIRECT_MAX_COUNT 20
+#endif
+
+#ifndef CPPHTTPLIB_PAYLOAD_MAX_LENGTH
+#define CPPHTTPLIB_PAYLOAD_MAX_LENGTH ((std::numeric_limits<size_t>::max)())
+#endif
+
+#ifndef CPPHTTPLIB_RECV_BUFSIZ
+#define CPPHTTPLIB_RECV_BUFSIZ size_t(4096u)
+#endif
+
+#ifndef CPPHTTPLIB_THREAD_POOL_COUNT
+#define CPPHTTPLIB_THREAD_POOL_COUNT \
+ ((std::max)(8u, std::thread::hardware_concurrency() - 1))
+#endif
+
+/*
+ * Headers
+ */
+
+#ifdef _WIN32
+#ifndef _CRT_SECURE_NO_WARNINGS
+#define _CRT_SECURE_NO_WARNINGS
+#endif //_CRT_SECURE_NO_WARNINGS
+
+#ifndef _CRT_NONSTDC_NO_DEPRECATE
+#define _CRT_NONSTDC_NO_DEPRECATE
+#endif //_CRT_NONSTDC_NO_DEPRECATE
+
+#if defined(_MSC_VER)
+#ifdef _WIN64
+using ssize_t = __int64;
+#else
+using ssize_t = int;
+#endif
+
+#if _MSC_VER < 1900
+#define snprintf _snprintf_s
+#endif
+#endif // _MSC_VER
+
+#ifndef S_ISREG
+#define S_ISREG(m) (((m)&S_IFREG) == S_IFREG)
+#endif // S_ISREG
+
+#ifndef S_ISDIR
+#define S_ISDIR(m) (((m)&S_IFDIR) == S_IFDIR)
+#endif // S_ISDIR
+
+#ifndef NOMINMAX
+#define NOMINMAX
+#endif // NOMINMAX
+
+#include <io.h>
+#include <winsock2.h>
+#include <ws2tcpip.h>
+
+#ifndef WSA_FLAG_NO_HANDLE_INHERIT
+#define WSA_FLAG_NO_HANDLE_INHERIT 0x80
+#endif
+
+#ifdef _MSC_VER
+#pragma comment(lib, "ws2_32.lib")
+#endif
+
+#ifndef strcasecmp
+#define strcasecmp _stricmp
+#endif // strcasecmp
+
+using socket_t = SOCKET;
+#ifdef CPPHTTPLIB_USE_POLL
+#define poll(fds, nfds, timeout) WSAPoll(fds, nfds, timeout)
+#endif
+
+#else // not _WIN32
+
+#include <arpa/inet.h>
+#include <cstring>
+#include <ifaddrs.h>
+#include <netdb.h>
+#include <netinet/in.h>
+#ifdef CPPHTTPLIB_USE_POLL
+#include <poll.h>
+#endif
+#include <csignal>
+#include <pthread.h>
+#include <sys/select.h>
+#include <sys/socket.h>
+#include <unistd.h>
+
+using socket_t = int;
+#define INVALID_SOCKET (-1)
+#endif //_WIN32
+
+#include <array>
+#include <atomic>
+#include <cassert>
+#include <climits>
+#include <condition_variable>
+#include <errno.h>
+#include <fcntl.h>
+#include <fstream>
+#include <functional>
+#include <list>
+#include <map>
+#include <memory>
+#include <mutex>
+#include <random>
+#include <regex>
+#include <string>
+#include <sys/stat.h>
+#include <thread>
+
+#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
+#include <openssl/err.h>
+#include <openssl/md5.h>
+#include <openssl/ssl.h>
+#include <openssl/x509v3.h>
+
+#include <iomanip>
+#include <iostream>
+#include <sstream>
+
+// #if OPENSSL_VERSION_NUMBER < 0x1010100fL
+// #error Sorry, OpenSSL versions prior to 1.1.1 are not supported
+// #endif
+
+#if OPENSSL_VERSION_NUMBER < 0x10100000L
+#include <openssl/crypto.h>
+inline const unsigned char *ASN1_STRING_get0_data(const ASN1_STRING *asn1) {
+ return M_ASN1_STRING_data(asn1);
+}
+#endif
+#endif
+
+#ifdef CPPHTTPLIB_ZLIB_SUPPORT
+#include <zlib.h>
+#endif
+/*
+ * Declaration
+ */
+namespace httplib {
+
+namespace detail {
+
+struct ci {
+ bool operator()(const std::string &s1, const std::string &s2) const {
+ return std::lexicographical_compare(
+ s1.begin(), s1.end(), s2.begin(), s2.end(),
+ [](char c1, char c2) { return ::tolower(c1) < ::tolower(c2); });
+ }
+};
+
+} // namespace detail
+
+using Headers = std::multimap<std::string, std::string, detail::ci>;
+
+using Params = std::multimap<std::string, std::string>;
+using Match = std::smatch;
+
+using Progress = std::function<bool(uint64_t current, uint64_t total)>;
+
+struct Response;
+using ResponseHandler = std::function<bool(const Response &response)>;
+
+struct MultipartFormData {
+ std::string name;
+ std::string content;
+ std::string filename;
+ std::string content_type;
+};
+using MultipartFormDataItems = std::vector<MultipartFormData>;
+using MultipartFormDataMap = std::multimap<std::string, MultipartFormData>;
+
+class DataSink {
+public:
+ DataSink() = default;
+ DataSink(const DataSink &) = delete;
+ DataSink &operator=(const DataSink &) = delete;
+ DataSink(DataSink &&) = delete;
+ DataSink &operator=(DataSink &&) = delete;
+
+ std::function<void(const char *data, size_t data_len)> write;
+ std::function<void()> done;
+ std::function<bool()> is_writable;
+};
+
+using ContentProvider =
+ std::function<void(size_t offset, size_t length, DataSink &sink)>;
+
+using ContentReceiver =
+ std::function<bool(const char *data, size_t data_length)>;
+
+using MultipartContentHeader =
+ std::function<bool(const MultipartFormData &file)>;
+
+class ContentReader {
+public:
+ using Reader = std::function<bool(ContentReceiver receiver)>;
+ using MultipartReader = std::function<bool(MultipartContentHeader header,
+ ContentReceiver receiver)>;
+
+ ContentReader(Reader reader, MultipartReader muitlpart_reader)
+ : reader_(reader), muitlpart_reader_(muitlpart_reader) {}
+
+ bool operator()(MultipartContentHeader header,
+ ContentReceiver receiver) const {
+ return muitlpart_reader_(header, receiver);
+ }
+
+ bool operator()(ContentReceiver receiver) const { return reader_(receiver); }
+
+ Reader reader_;
+ MultipartReader muitlpart_reader_;
+};
+
+using Range = std::pair<ssize_t, ssize_t>;
+using Ranges = std::vector<Range>;
+
+struct Request {
+ std::string method;
+ std::string path;
+ Headers headers;
+ std::string body;
+
+ std::string remote_addr;
+ int remote_port = -1;
+
+ // for server
+ std::string version;
+ std::string target;
+ Params params;
+ MultipartFormDataMap files;
+ Ranges ranges;
+ Match matches;
+
+ // for client
+ size_t redirect_count = CPPHTTPLIB_REDIRECT_MAX_COUNT;
+ ResponseHandler response_handler;
+ ContentReceiver content_receiver;
+ Progress progress;
+
+#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
+ const SSL *ssl;
+#endif
+
+ bool has_header(const char *key) const;
+ std::string get_header_value(const char *key, size_t id = 0) const;
+ size_t get_header_value_count(const char *key) const;
+ void set_header(const char *key, const char *val);
+ void set_header(const char *key, const std::string &val);
+
+ bool has_param(const char *key) const;
+ std::string get_param_value(const char *key, size_t id = 0) const;
+ size_t get_param_value_count(const char *key) const;
+
+ bool is_multipart_form_data() const;
+
+ bool has_file(const char *key) const;
+ MultipartFormData get_file_value(const char *key) const;
+
+ // private members...
+ size_t content_length;
+ ContentProvider content_provider;
+};
+
+struct Response {
+ std::string version;
+ int status = -1;
+ Headers headers;
+ std::string body;
+
+ bool has_header(const char *key) const;
+ std::string get_header_value(const char *key, size_t id = 0) const;
+ size_t get_header_value_count(const char *key) const;
+ void set_header(const char *key, const char *val);
+ void set_header(const char *key, const std::string &val);
+
+ void set_redirect(const char *url, int status = 302);
+ void set_content(const char *s, size_t n, const char *content_type);
+ void set_content(std::string s, const char *content_type);
+
+ void set_content_provider(
+ size_t length,
+ std::function<void(size_t offset, size_t length, DataSink &sink)>
+ provider,
+ std::function<void()> resource_releaser = [] {});
+
+ void set_chunked_content_provider(
+ std::function<void(size_t offset, DataSink &sink)> provider,
+ std::function<void()> resource_releaser = [] {});
+
+ Response() = default;
+ Response(const Response &) = default;
+ Response &operator=(const Response &) = default;
+ Response(Response &&) = default;
+ Response &operator=(Response &&) = default;
+ ~Response() {
+ if (content_provider_resource_releaser) {
+ content_provider_resource_releaser();
+ }
+ }
+
+ // private members...
+ size_t content_length = 0;
+ ContentProvider content_provider;
+ std::function<void()> content_provider_resource_releaser;
+};
+
+class Stream {
+public:
+ virtual ~Stream() = default;
+
+ virtual bool is_readable() const = 0;
+ virtual bool is_writable() const = 0;
+
+ virtual ssize_t read(char *ptr, size_t size) = 0;
+ virtual ssize_t write(const char *ptr, size_t size) = 0;
+ virtual void get_remote_ip_and_port(std::string &ip, int &port) const = 0;
+
+ template <typename... Args>
+ ssize_t write_format(const char *fmt, const Args &... args);
+ ssize_t write(const char *ptr);
+ ssize_t write(const std::string &s);
+};
+
+class TaskQueue {
+public:
+ TaskQueue() = default;
+ virtual ~TaskQueue() = default;
+
+ virtual void enqueue(std::function<void()> fn) = 0;
+ virtual void shutdown() = 0;
+
+ virtual void on_idle(){};
+};
+
+class ThreadPool : public TaskQueue {
+public:
+ explicit ThreadPool(size_t n) : shutdown_(false) {
+ while (n) {
+ threads_.emplace_back(worker(*this));
+ n--;
+ }
+ }
+
+ ThreadPool(const ThreadPool &) = delete;
+ ~ThreadPool() override = default;
+
+ void enqueue(std::function<void()> fn) override {
+ std::unique_lock<std::mutex> lock(mutex_);
+ jobs_.push_back(fn);
+ cond_.notify_one();
+ }
+
+ void shutdown() override {
+ // Stop all worker threads...
+ {
+ std::unique_lock<std::mutex> lock(mutex_);
+ shutdown_ = true;
+ }
+
+ cond_.notify_all();
+
+ // Join...
+ for (auto &t : threads_) {
+ t.join();
+ }
+ }
+
+private:
+ struct worker {
+ explicit worker(ThreadPool &pool) : pool_(pool) {}
+
+ void operator()() {
+ for (;;) {
+ std::function<void()> fn;
+ {
+ std::unique_lock<std::mutex> lock(pool_.mutex_);
+
+ pool_.cond_.wait(
+ lock, [&] { return !pool_.jobs_.empty() || pool_.shutdown_; });
+
+ if (pool_.shutdown_ && pool_.jobs_.empty()) { break; }
+
+ fn = pool_.jobs_.front();
+ pool_.jobs_.pop_front();
+ }
+
+ assert(true == static_cast<bool>(fn));
+ fn();
+ }
+ }
+
+ ThreadPool &pool_;
+ };
+ friend struct worker;
+
+ std::vector<std::thread> threads_;
+ std::list<std::function<void()>> jobs_;
+
+ bool shutdown_;
+
+ std::condition_variable cond_;
+ std::mutex mutex_;
+};
+
+using Logger = std::function<void(const Request &, const Response &)>;
+
+class Server {
+public:
+ using Handler = std::function<void(const Request &, Response &)>;
+ using HandlerWithContentReader = std::function<void(
+ const Request &, Response &, const ContentReader &content_reader)>;
+ using Expect100ContinueHandler =
+ std::function<int(const Request &, Response &)>;
+
+ Server();
+
+ virtual ~Server();
+
+ virtual bool is_valid() const;
+
+ Server &Get(const char *pattern, Handler handler);
+ Server &Post(const char *pattern, Handler handler);
+ Server &Post(const char *pattern, HandlerWithContentReader handler);
+ Server &Put(const char *pattern, Handler handler);
+ Server &Put(const char *pattern, HandlerWithContentReader handler);
+ Server &Patch(const char *pattern, Handler handler);
+ Server &Patch(const char *pattern, HandlerWithContentReader handler);
+ Server &Delete(const char *pattern, Handler handler);
+ Server &Delete(const char *pattern, HandlerWithContentReader handler);
+ Server &Options(const char *pattern, Handler handler);
+
+ [[deprecated]] bool set_base_dir(const char *dir,
+ const char *mount_point = nullptr);
+ bool set_mount_point(const char *mount_point, const char *dir);
+ bool remove_mount_point(const char *mount_point);
+ void set_file_extension_and_mimetype_mapping(const char *ext,
+ const char *mime);
+ void set_file_request_handler(Handler handler);
+
+ void set_error_handler(Handler handler);
+ void set_logger(Logger logger);
+
+ void set_expect_100_continue_handler(Expect100ContinueHandler handler);
+
+ void set_keep_alive_max_count(size_t count);
+ void set_read_timeout(time_t sec, time_t usec);
+ void set_payload_max_length(size_t length);
+
+ bool bind_to_port(const char *host, int port, int socket_flags = 0);
+ int bind_to_any_port(const char *host, int socket_flags = 0);
+ bool listen_after_bind();
+
+ bool listen(const char *host, int port, int socket_flags = 0);
+
+ bool is_running() const;
+ void stop();
+
+ std::function<TaskQueue *(void)> new_task_queue;
+
+protected:
+ bool process_request(Stream &strm, bool last_connection,
+ bool &connection_close,
+ const std::function<void(Request &)> &setup_request);
+
+ size_t keep_alive_max_count_;
+ time_t read_timeout_sec_;
+ time_t read_timeout_usec_;
+ size_t payload_max_length_;
+
+private:
+ using Handlers = std::vector<std::pair<std::regex, Handler>>;
+ using HandlersForContentReader =
+ std::vector<std::pair<std::regex, HandlerWithContentReader>>;
+
+ socket_t create_server_socket(const char *host, int port,
+ int socket_flags) const;
+ int bind_internal(const char *host, int port, int socket_flags);
+ bool listen_internal();
+
+ bool routing(Request &req, Response &res, Stream &strm);
+ bool handle_file_request(Request &req, Response &res, bool head = false);
+ bool dispatch_request(Request &req, Response &res, Handlers &handlers);
+ bool dispatch_request_for_content_reader(Request &req, Response &res,
+ ContentReader content_reader,
+ HandlersForContentReader &handlers);
+
+ bool parse_request_line(const char *s, Request &req);
+ bool write_response(Stream &strm, bool last_connection, const Request &req,
+ Response &res);
+ bool write_content_with_provider(Stream &strm, const Request &req,
+ Response &res, const std::string &boundary,
+ const std::string &content_type);
+ bool read_content(Stream &strm, Request &req, Response &res);
+ bool
+ read_content_with_content_receiver(Stream &strm, Request &req, Response &res,
+ ContentReceiver receiver,
+ MultipartContentHeader multipart_header,
+ ContentReceiver multipart_receiver);
+ bool read_content_core(Stream &strm, Request &req, Response &res,
+ ContentReceiver receiver,
+ MultipartContentHeader mulitpart_header,
+ ContentReceiver multipart_receiver);
+
+ virtual bool process_and_close_socket(socket_t sock);
+
+ std::atomic<bool> is_running_;
+ std::atomic<socket_t> svr_sock_;
+ std::vector<std::pair<std::string, std::string>> base_dirs_;
+ std::map<std::string, std::string> file_extension_and_mimetype_map_;
+ Handler file_request_handler_;
+ Handlers get_handlers_;
+ Handlers post_handlers_;
+ HandlersForContentReader post_handlers_for_content_reader_;
+ Handlers put_handlers_;
+ HandlersForContentReader put_handlers_for_content_reader_;
+ Handlers patch_handlers_;
+ HandlersForContentReader patch_handlers_for_content_reader_;
+ Handlers delete_handlers_;
+ HandlersForContentReader delete_handlers_for_content_reader_;
+ Handlers options_handlers_;
+ Handler error_handler_;
+ Logger logger_;
+ Expect100ContinueHandler expect_100_continue_handler_;
+};
+
+class Client {
+public:
+ explicit Client(const std::string &host, int port = 80,
+ const std::string &client_cert_path = std::string(),
+ const std::string &client_key_path = std::string());
+
+ virtual ~Client();
+
+ virtual bool is_valid() const;
+
+ std::shared_ptr<Response> Get(const char *path);
+
+ std::shared_ptr<Response> Get(const char *path, const Headers &headers);
+
+ std::shared_ptr<Response> Get(const char *path, Progress progress);
+
+ std::shared_ptr<Response> Get(const char *path, const Headers &headers,
+ Progress progress);
+
+ std::shared_ptr<Response> Get(const char *path,
+ ContentReceiver content_receiver);
+
+ std::shared_ptr<Response> Get(const char *path, const Headers &headers,
+ ContentReceiver content_receiver);
+
+ std::shared_ptr<Response>
+ Get(const char *path, ContentReceiver content_receiver, Progress progress);
+
+ std::shared_ptr<Response> Get(const char *path, const Headers &headers,
+ ContentReceiver content_receiver,
+ Progress progress);
+
+ std::shared_ptr<Response> Get(const char *path, const Header