Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 151 additions & 0 deletions include/crow/compression.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
#ifdef CROW_ENABLE_COMPRESSION
#pragma once

#ifndef ASIO_STANDALONE
#define ASIO_STANDALONE
#endif
#include <asio.hpp>
#include <memory>
#include <string>
#include <zlib.h>

Expand Down Expand Up @@ -93,6 +98,152 @@ namespace crow

return inflated_string;
}

class Compressor
{
public:
Compressor(bool reset_before_compress, int window_bits, int level):
reset_before_compress_(reset_before_compress), window_bits_(window_bits)
{
stream_ = std::make_unique<z_stream>();
stream_->zalloc = 0;
stream_->zfree = 0;
stream_->opaque = 0;

::deflateInit2(stream_.get(),
level,
Z_DEFLATED,
-window_bits_,
8,
Z_DEFAULT_STRATEGY);
}

~Compressor()
{
::deflateEnd(stream_.get());
}

bool needs_reset() const
{
return reset_before_compress_;
}

int window_bits() const
{
return window_bits_;
}

std::string compress(const std::string& src)
{
if (reset_before_compress_)
{
::deflateReset(stream_.get());
}

stream_->next_in = reinterpret_cast<uint8_t*>(const_cast<char*>(src.c_str()));
stream_->avail_in = src.size();

constexpr const uint64_t bufferSize = 8192;
asio::streambuf buffer;
do
{
asio::streambuf::mutable_buffers_type chunk = buffer.prepare(bufferSize);

uint8_t* next_out = asio::buffer_cast<uint8_t*>(chunk);

stream_->next_out = next_out;
stream_->avail_out = bufferSize;

::deflate(stream_.get(), reset_before_compress_ ? Z_FINISH : Z_SYNC_FLUSH);

uint64_t outputSize = stream_->next_out - next_out;
buffer.commit(outputSize);
} while (stream_->avail_out == 0);

uint64_t buffer_size = buffer.size();
if (!reset_before_compress_)
{
buffer_size -= 4;
}

return std::string(asio::buffer_cast<const char*>(buffer.data()), buffer_size);
}

private:
std::unique_ptr<z_stream> stream_;

bool reset_before_compress_;
int window_bits_;
};

class Decompressor
{
public:
Decompressor(bool reset_before_decompress, int window_bits):
reset_before_decompress_(reset_before_decompress), window_bits_(window_bits)
{
stream_ = std::make_unique<z_stream>();
stream_->zalloc = 0;
stream_->zfree = 0;
stream_->opaque = 0;

::inflateInit2(stream_.get(), -window_bits_);
}

~Decompressor()
{
inflateEnd(stream_.get());
}

bool needs_reset() const
{
return reset_before_decompress_;
}

int window_bits() const
{
return window_bits_;
}

std::string decompress(std::string src)
{
if (reset_before_decompress_)
{
inflateReset(stream_.get());
}

src.push_back('\x00');
src.push_back('\x00');
src.push_back('\xff');
src.push_back('\xff');

stream_->next_in = reinterpret_cast<uint8_t*>(const_cast<char*>(src.c_str()));
stream_->avail_in = src.size();

constexpr const uint64_t bufferSize = 8192;
asio::streambuf buffer;
do
{
asio::streambuf::mutable_buffers_type chunk = buffer.prepare(bufferSize);

uint8_t* next_out = asio::buffer_cast<uint8_t*>(chunk);

stream_->next_out = next_out;
stream_->avail_out = bufferSize;

::inflate(stream_.get(), reset_before_decompress_ ? Z_FINISH : Z_SYNC_FLUSH);
buffer.commit(stream_->next_out - next_out);
} while (stream_->avail_out == 0);

return std::string(asio::buffer_cast<const char*>(buffer.data()), buffer.size());
}

private:
std::unique_ptr<z_stream> stream_;

bool reset_before_decompress_;
int window_bits_;
};
} // namespace compression
} // namespace crow

Expand Down
68 changes: 66 additions & 2 deletions include/crow/websocket.h
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
#pragma once
#include <array>
#include <memory>
#include "crow/logging.h"
#include "crow/socket_adaptors.h"
#include "crow/http_request.h"
#include "crow/TinySHA1.hpp"
#include "crow/utility.h"
#include "crow/compression.h"

namespace crow
{
Expand Down Expand Up @@ -107,6 +109,17 @@ namespace crow
userdata(ud);
}

#ifdef CROW_ENABLE_COMPRESSION
std::string extensions_header = req.get_header_value("Sec-WebSocket-Extensions");
if (extensions_header.find("permessage-deflate") != std::string::npos)
{
const bool reset_compressor = extensions_header.find("server_no_context_takeover") != std::string::npos;
compressor_ = std::make_unique<compression::Compressor>(reset_compressor, compression::DEFLATE, Z_BEST_COMPRESSION);
const bool reset_decompressor = extensions_header.find("client_no_context_takeover") != std::string::npos;
decompressor_ = std::make_unique<compression::Decompressor>(reset_decompressor, compression::DEFLATE);
}
#endif

// Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==
// Sec-WebSocket-Version: 13
std::string magic = req.get_header_value("Sec-WebSocket-Key") + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
Expand Down Expand Up @@ -186,13 +199,29 @@ namespace crow
/// Send a binary encoded message.
void send_binary(std::string msg) override
{
send_data(0x2, std::move(msg));
int opcode = 0x2;
#ifdef CROW_ENABLE_COMPRESSION
if (compressor_)
{
opcode += 0x40;
msg = compressor_->compress(msg);
}
#endif
send_data(opcode, std::move(msg));
}

/// Send a plaintext message.
void send_text(std::string msg) override
{
send_data(0x1, std::move(msg));
int opcode = 0x1;
#ifdef CROW_ENABLE_COMPRESSION
if (compressor_)
{
opcode += 0x40;
msg = compressor_->compress(msg);
}
#endif
send_data(opcode, std::move(msg));
}

/// Send a close signal.
Expand Down Expand Up @@ -265,6 +294,19 @@ namespace crow
write_buffers_.emplace_back(header);
write_buffers_.emplace_back(std::move(hello));
write_buffers_.emplace_back(crlf);
#ifdef CROW_ENABLE_COMPRESSION
if (compressor_ && decompressor_)
{
write_buffers_.emplace_back(
"Sec-WebSocket-Extensions: permessage-deflate"
"; server_max_window_bits=" +
std::to_string(compressor_->window_bits()) +
"; client_max_window_bits=" + std::to_string(decompressor_->window_bits()) +
(compressor_->needs_reset() ? "; server_no_context_takeover" : "") +
(decompressor_->needs_reset() ? "; client_no_context_takeover" : ""));
write_buffers_.emplace_back(crlf);
}
#endif
write_buffers_.emplace_back(crlf);
do_write();
if (open_handler_)
Expand Down Expand Up @@ -528,6 +570,12 @@ namespace crow
return mini_header_ & 0x8000;
}

/// Check if payload is compressed
bool is_compressed()
{
return mini_header_ & 0x4000;
}

/// Extract the opcode from the header.
int opcode()
{
Expand Down Expand Up @@ -555,7 +603,11 @@ namespace crow
if (is_FIN())
{
if (message_handler_)
#ifdef CROW_ENABLE_COMPRESSION
message_handler_(*this, is_compressed() && decompressor_ ? decompressor_->decompress(message_) : message_, is_binary_);
#else
message_handler_(*this, message_, is_binary_);
#endif
message_.clear();
}
}
Expand All @@ -567,7 +619,11 @@ namespace crow
if (is_FIN())
{
if (message_handler_)
#ifdef CROW_ENABLE_COMPRESSION
message_handler_(*this, is_compressed() && decompressor_ ? decompressor_->decompress(message_) : message_, is_binary_);
#else
message_handler_(*this, message_, is_binary_);
#endif
message_.clear();
}
}
Expand All @@ -579,7 +635,11 @@ namespace crow
if (is_FIN())
{
if (message_handler_)
#ifdef CROW_ENABLE_COMPRESSION
message_handler_(*this, is_compressed() && decompressor_ ? decompressor_->decompress(message_) : message_, is_binary_);
#else
message_handler_(*this, message_, is_binary_);
#endif
message_.clear();
}
}
Expand Down Expand Up @@ -734,6 +794,10 @@ namespace crow

std::shared_ptr<void> anchor_ = std::make_shared<int>(); // Value is just for placeholding

#ifdef CROW_ENABLE_COMPRESSION
std::unique_ptr<compression::Compressor> compressor_;
std::unique_ptr<compression::Decompressor> decompressor_;
#endif
std::function<void(crow::websocket::connection&)> open_handler_;
std::function<void(crow::websocket::connection&, const std::string&, bool)> message_handler_;
std::function<void(crow::websocket::connection&, const std::string&)> close_handler_;
Expand Down