// Skip to content
//
// File serializer.hpp
//
// File List > serializer.hpp
//
// Go to the documentation of this file


#pragma once
#include <algorithm>
#include <array>
#include <bit>
#include <cassert>
#include <cmath>
#include <concepts>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <functional>
#include <iomanip>
#include <iostream>
#include <map>
#include <memory>
#include <span>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>

#include "shared_types.h"

namespace cerializer {
    /// Marker value written at the start of every encoded message.
    ///
    /// Declared `inline constexpr` so that every translation unit including
    /// this header refers to the same object; the inline templates in this
    /// header bind the marker to `const&` parameters, which would otherwise
    /// refer to a different internal-linkage copy in each translation unit.
    inline constexpr std::uint16_t magicHead = 0xCAFE;

    /// Marker value written at the end of every encoded message.
    inline constexpr std::uint16_t magicFoot = 0xFACE;

    /// Renders every element of `data` as two lowercase hexadecimal digits.
    ///
    /// The range is taken by const reference, so it is not copied and built-in
    /// arrays are accepted as well as containers and spans.
    ///
    /// @param data Range usable in a range-based for loop whose elements
    ///             convert to `char`; each element is treated as one byte.
    /// @return The digits of all bytes concatenated in iteration order
    ///         (e.g. bytes 0x01, 0xff give "01ff"); empty for an empty range.
    template <typename Range>
    std::string hexify(const Range& data) {
        static constexpr char digits[] = "0123456789abcdef";

        std::string out;
        for (const char c : data) {
            // Go through unsigned char so bytes >= 0x80 are not sign-extended.
            const auto byte = static_cast<unsigned char>(c);
            out.push_back(digits[byte >> 4]);
            out.push_back(digits[byte & 0x0F]);
        }
        return out;
    }

    /// Reinterprets the object representation of `thing` as an array of
    /// `sizeof(T)` chars. The bytes are in the host's native order; no
    /// byte swapping is performed here.
    template <typename T>
    constexpr auto toCharArray(const T& thing) {
        using RawBytes = std::array<char, sizeof(T)>;
        return std::bit_cast<RawBytes>(thing);
    }

    /// States of the frame parser driven by Deserializer::ParseStream.
    enum ParseMode {
        START,   ///< Searching the buffered bytes for the header token.
        PRE_END, ///< Header token located; buffered data is moved to the buffer start and the footer search begins.
        END,     ///< Footer token not found yet; the search continues without moving the buffered data.
        EMIT,    ///< Footer token located; the framed bytes are turned into a message and handed to the callback.
    };

    /// Type trait: derives from `std::true_type` only for (cv-unqualified)
    /// specialisations of `std::array`, from `std::false_type` otherwise.
    template <typename Candidate>
    struct is_std_array: std::bool_constant<false> {};

    /// Partial specialisation matching `std::array<Element, Extent>`.
    template <typename Element, std::size_t Extent>
    struct is_std_array<std::array<Element, Extent>>: std::bool_constant<true> {};

    /// Convenience variable template for `is_std_array<Candidate>::value`.
    template <typename Candidate>
    constexpr bool is_std_array_v = is_std_array<Candidate>::value;

    /// Satisfied by built-in array types and by `std::array` specialisations.
    /// `pack` and `unpack_one` use it to exclude such types from byte swapping.
    template <class T>
    concept Indexable = std::is_array_v<T> || is_std_array_v<T>;

    /// Satisfied by types that expose contiguous, indexable storage:
    ///   - a nested `value_type`,
    ///   - `data()` returning exactly a pointer to the element type,
    ///   - `operator[]` (indexed with `N`) returning exactly a reference to it.
    ///
    /// @tparam T Type being checked.
    /// @tparam V Required element type. The default, `std::nullptr_t`, is a
    ///           placeholder meaning "use `T::value_type`".
    /// @tparam N Type of the index passed to `operator[]`.
    ///
    /// `std::nullptr_t` is spelled with its namespace because the unqualified
    /// name is only guaranteed by `<stddef.h>`, not by the C++ headers used here.
    template <class T, typename V = std::nullptr_t, typename N = int>
    concept Container = requires(T obj, V, N idx) {
        typename T::value_type;
        {
            obj.data()
        } -> std::same_as<std::conditional_t<std::is_same_v<V, std::nullptr_t>, typename T::value_type, V>*>;
        {
            obj[idx]
        } -> std::same_as<std::conditional_t<std::is_same_v<V, std::nullptr_t>, typename T::value_type, V>&>;
    };

    template <std::integral T, std::integral Y>
    constexpr inline std::common_type<T, Y>::type postfixAdd(T& initial, const Y& add) {
        T x = initial;
        initial = static_cast<T>(initial + add);
        return x;
    }

    /// Serialises the arguments back to back into one char array.
    ///
    /// On a little-endian host each argument's bytes are copied unchanged. On a
    /// big-endian host the bytes of every multi-byte argument that is not
    /// `Indexable` are reversed, so such values are always emitted
    /// least-significant byte first.
    ///
    /// @param args Trivially copyable values, encoded in argument order.
    /// @return Array whose size is the sum of `sizeof` of all argument types.
    template <typename... Ts>
        requires(std::is_trivially_copyable_v<Ts> && ...)
    constexpr inline std::array<char, (sizeof(Ts) + ...)> pack(const Ts&... args) {
        constexpr bool hostIsBigEndian = (std::endian::native == std::endian::big);

        std::array<char, (sizeof(Ts) + ...)> wire;
        char*                                cursor = wire.data();

        // Writes one field at the cursor and advances past it.
        const auto append = [&cursor]<typename Field>(const Field& field) {
            if constexpr (hostIsBigEndian && sizeof(Field) > 1 && !Indexable<Field>) {
                const auto native = std::bit_cast<std::array<char, sizeof(Field)>>(field);
                std::copy(native.rbegin(), native.rend(), cursor);
            } else {
                std::memcpy(cursor, &field, sizeof(Field));
            }
            cursor += sizeof(Field);
        };

        // The comma fold evaluates left to right, i.e. in argument order.
        (append(args), ...);
        return wire;
    }

    /// Decodes one value of type `T` from the first `sizeof(T)` bytes of `data`.
    ///
    /// This is the inverse of `pack` for a single field: the bytes are copied
    /// into a `T`, and on a big-endian host the bytes of multi-byte,
    /// non-`Indexable` types are reversed afterwards.
    ///
    /// @tparam T    Trivially copyable, default-constructible type to decode.
    /// @param data  Char container holding at least `sizeof(T)` bytes.
    /// @return The decoded value.
    template <typename T, typename Cont>
        requires Container<Cont, char>
    constexpr inline T unpack_one(const Cont& data) {
        // memcpy into anything that is not trivially copyable is undefined.
        static_assert(std::is_trivially_copyable_v<T>, "unpack_one requires a trivially copyable type");

        // The Container concept does not demand size(), so only check the
        // length (in debug builds) for containers that can report it.
        if constexpr (requires { data.size(); }) {
            assert(static_cast<std::size_t>(data.size()) >= sizeof(T));
        }

        T value;
        std::memcpy(&value, data.data(), sizeof(T));
        if constexpr ((std::endian::native == std::endian::big) && sizeof(T) > 1 && !Indexable<T>) {
            auto bytes = std::bit_cast<std::array<char, sizeof(T)>>(value);
            std::reverse(bytes.begin(), bytes.end());
            return std::bit_cast<T>(bytes);
        }
        return value;
    }

    /// Decodes a `Dest` from `binaryData`.
    ///
    /// With no field types given, `Dest` itself is decoded from the start of
    /// the data. Otherwise one value per type in `Ts` is decoded from
    /// consecutive byte ranges and `Dest` is list-initialised from them.
    ///
    /// @param binaryData Char container holding the encoded fields.
    /// @return The decoded object.
    template <typename Dest, typename... Ts, typename Cont>
        requires(std::is_trivially_copyable_v<Ts> && ...) && Container<Cont, char>
    constexpr inline Dest unpack(const Cont& binaryData) {
        if constexpr (sizeof...(Ts) == 0) {
            return unpack_one<Dest>(binaryData);
        } else {
            std::span<char> bytes(binaryData);
            size_t          cursor = 0;
            // Braced initialisation evaluates its elements left to right, so
            // the cursor advances through the fields in declaration order.
            return Dest{unpack_one<Ts>(bytes.subspan(postfixAdd(cursor, sizeof(Ts)), sizeof(Ts)))...};
        }
    }

    /// Describes how one field type appears in a rendered format string:
    ///   - element 0: number of characters reserved for the field,
    ///   - element 1: size of the field in bytes,
    ///   - element 2: format character emitted for the field.
    using TypeCharSpec = std::tuple<uint8_t, uint8_t, char>;

    /// Number of format-string characters reserved for a generic field of
    /// `size` bytes.
    ///
    /// Computes ceil(log10(size)) with integer arithmetic, because the
    /// floating-point `<cmath>` functions are not portable in constant
    /// expressions. The result is clamped to at least 1: `renderFormat` emits
    /// one character per field, so a width of 0 (which the formula gives for
    /// 1-byte types) would leave the rendered array too small.
    constexpr uint8_t formatWidth(std::size_t size) {
        uint8_t width = 0;
        // Each step is a ceiling division by ten.
        for (std::size_t rest = size; rest > 1; rest = (rest + 9) / 10) {
            ++width;
        }
        return width == 0 ? uint8_t{1} : width;
    }

    /// Generic description for trivially copyable types without a dedicated
    /// specialisation below; uses the format character 'P'.
    template <typename T>
        requires std::is_trivially_copyable_v<T> && (!std::is_bounded_array_v<T>)
    constexpr TypeCharSpec formatSize() {
        return {formatWidth(sizeof(T)), static_cast<uint8_t>(sizeof(T)), 'P'};
    }

    /// Description for fixed-size char arrays; uses the format character 's'.
    template <typename T>
        requires std::is_bounded_array_v<T> && std::same_as<char, std::remove_all_extents_t<T>>
    constexpr TypeCharSpec formatSize() {
        return {formatWidth(sizeof(T)), static_cast<uint8_t>(sizeof(T)), 's'};
    }

    template <>
    constexpr TypeCharSpec formatSize<char>() {
        return {1, 1, 'c'};
    }

    template <>
    constexpr TypeCharSpec formatSize<signed char>() {
        return {1, 1, 'b'};
    }

    template <>
    constexpr TypeCharSpec formatSize<unsigned char>() {
        return {1, 1, 'B'};
    }

    template <>
    constexpr TypeCharSpec formatSize<bool>() {
        return {1, 1, '?'};
    }

    template <>
    constexpr TypeCharSpec formatSize<short>() {
        return {1, 2, 'h'};
    }

    template <>
    constexpr TypeCharSpec formatSize<unsigned short>() {
        return {1, 2, 'H'};
    }

    template <>
    constexpr TypeCharSpec formatSize<int>() {
        return {1, 4, 'i'};
    }

    template <>
    constexpr TypeCharSpec formatSize<unsigned int>() {
        return {1, 4, 'I'};
    }

    // `long` is 4 bytes on some platforms and 8 on others, and `pack` always
    // emits sizeof(long) bytes. Report the real size and use the 8-byte
    // format character when the type is 8 bytes wide, so the description
    // matches the bytes that are actually written.
    template <>
    constexpr TypeCharSpec formatSize<long>() {
        return {1, static_cast<uint8_t>(sizeof(long)), sizeof(long) == 8 ? 'q' : 'l'};
    }

    template <>
    constexpr TypeCharSpec formatSize<unsigned long>() {
        return {1, static_cast<uint8_t>(sizeof(unsigned long)), sizeof(unsigned long) == 8 ? 'Q' : 'L'};
    }

    template <>
    constexpr TypeCharSpec formatSize<long long>() {
        return {1, 8, 'q'};
    }

    template <>
    constexpr TypeCharSpec formatSize<unsigned long long>() {
        return {1, 8, 'Q'};
    }

    template <>
    constexpr TypeCharSpec formatSize<float>() {
        return {1, 4, 'f'};
    }

    template <>
    constexpr TypeCharSpec formatSize<double>() {
        return {1, 8, 'd'};
    }

    /// Array type holding a rendered format string for the field types `Ts`:
    /// one leading character plus the width `formatSize` reports per type.
    template <typename... Ts>
    using RenderedFormatString = std::array<char, (std::get<0>(formatSize<Ts>()) + ...) + 1>;

    /// Builds the format string for `Ts`: a '<' followed by the format
    /// character of each type, in order. Array elements beyond those
    /// characters are value-initialised (zero).
    template <typename... Ts>
        requires(std::is_trivially_copyable_v<Ts> && ...)
    constexpr inline RenderedFormatString<Ts...> renderFormat() {
        RenderedFormatString<Ts...> rendered{'<', static_cast<char>(std::get<2>(formatSize<Ts>()))...};
        return rendered;
    }

    /// Polymorphic base of every message type, so that messages of different
    /// concrete types can be handled through a single pointer type.
    class BasePacket {
    public:
        virtual ~BasePacket() = default;
        /// Returns the numeric type code of the concrete message.
        virtual constexpr uint8_t Code() = 0;
    };

    /// Owning pointer to a message held through its polymorphic base.
    using BasePointer = std::unique_ptr<BasePacket>;

    class MessageMaker {
    public:
        using TCreateMethod = std::function<BasePointer(const std::span<char>& binaryData)>;

        static bool Register(const uint8_t typeVal, TCreateMethod builder) {
            MessageMaker::makerMap[typeVal] = builder;
            return true;
        }

        static BasePointer Create(const uint8_t typeVal, const std::span<char>& binaryData) {
            if (auto it = MessageMaker::makerMap.find(typeVal); it != makerMap.end()) {
                return it->second(binaryData);
            }
            return nullptr;
        }

    private:
        static inline std::map<uint8_t, TCreateMethod> makerMap;
    };

    /// CRTP base shared by all concrete messages.
    ///
    /// An encoded message consists of the head marker, the one-byte type code,
    /// the packed fields and the foot marker, in that order.
    ///
    /// @tparam Derived    Concrete message class; must provide `encode()` and
    ///                    be list-initialisable from the decoded field values.
    /// @tparam TypeVal    Type code identifying the message.
    /// @tparam FieldTypes Types of the encoded fields, in encoding order.
    template <typename Derived, uint8_t TypeVal, typename... FieldTypes>
    class Message: public BasePacket {
    public:
        /// Set by registering a factory for this message with MessageMaker.
        static const bool registered;

    public:
        /// Takes the address of `registered` so that the flag, and with it the
        /// MessageMaker registration in its initialiser, is instantiated for
        /// every message type that can be constructed. Relying only on
        /// `assert(registered)` is not enough: the assertion expands to nothing
        /// when NDEBUG is defined, and an unused static member of a class
        /// template is never instantiated.
        constexpr Message() noexcept { static_cast<void>(&registered); }

        /// Type code of this message, through the polymorphic interface.
        constexpr uint8_t Code() override { return TypeVal; }

        /// Type code of this message.
        constexpr static uint8_t Type() { return TypeVal; }

        /// Combined size in bytes of all encoded fields (markers and type code excluded).
        constexpr static unsigned int Size() { return (sizeof(FieldTypes) + ...); }

        using MessageFormat = std::array<char, 4 + (std::get<0>(formatSize<FieldTypes>()) + ...)>;
        using BinaryMessage = std::array<char, (sizeof(FieldTypes) + ...) + 2 * sizeof(uint16_t) + sizeof(uint8_t)>;

        /// Format string describing the complete encoded message.
        constexpr static MessageFormat Format() { return renderFormat<uint16_t, uint8_t, FieldTypes..., uint16_t>(); }

        /// Encodes this message: head marker, type code, fields, foot marker.
        constexpr BinaryMessage ToBinary() const {
            const auto encodedData = static_cast<const Derived*>(this)->encode();
            return pack(magicHead, Type(), encodedData, magicFoot);
        }

        /// Decodes a message from a complete encoded array.
        constexpr static Derived LoadBinary(BinaryMessage& binaryData) {
            return LoadBinary(std::span<char>(binaryData));
        }

        /// Decodes a message from encoded bytes.
        ///
        /// @param binaryData Exactly one encoded message. Length, markers and
        ///                   type code are checked by assertions only, so
        ///                   callers must pass well-formed data in release builds.
        constexpr static Derived LoadBinary(const std::span<char>& binaryData) {
            const std::span<char> dataView(binaryData);
            assert(dataView.size() == std::tuple_size_v<BinaryMessage>);

            [[maybe_unused]] const uint16_t headCheck = unpack<uint16_t>(dataView.first(sizeof(magicHead)));
            [[maybe_unused]] const uint16_t footCheck = unpack<uint16_t>(dataView.last(sizeof(magicFoot)));

            assert(headCheck == magicHead);
            assert(footCheck == magicFoot);
            assert(static_cast<uint8_t>(dataView[sizeof(magicHead)]) == TypeVal);

            return unpack<Derived, FieldTypes...>(dataView.subspan(sizeof(magicHead) + sizeof(uint8_t), Size()));
        }
    };

    /// Registers a factory for this message type with MessageMaker.
    ///
    /// The factory returns a null pointer when the byte count is not the exact
    /// encoded size of the message, instead of letting LoadBinary slice a span
    /// that is too short.
    template <typename Derived, uint8_t TypeVal, typename... FieldTypes>
    const bool Message<Derived, TypeVal, FieldTypes...>::registered = MessageMaker::Register(
        TypeVal,
        [](const std::span<char>& binaryData) -> BasePointer {
            constexpr std::size_t frameSize = 2 * sizeof(uint16_t) + sizeof(uint8_t) + Derived::Size();
            if (binaryData.size() != frameSize) {
                return nullptr;
            }
            return std::make_unique<Derived>(Derived::LoadBinary(binaryData));
        }
    );

    /// Message with type code 0. Its fields are encoded in the order
    /// id, valid, x, y, z.
    class Target: public Message<Target, 0, uint32_t, bool, uint16_t, uint16_t, uint16_t> {
    public:
        const uint32_t id;    ///< `id` field of the message.
        const bool     valid; ///< `valid` flag of the message.
        const uint16_t x;     ///< `x` field of the message.
        const uint16_t y;     ///< `y` field of the message.
        const uint16_t z;     ///< `z` field of the message.

    public:
        /// Stores the given field values; asserts that the message type is registered.
        constexpr inline Target(
            uint32_t initial_id,
            bool     initial_valid,
            uint16_t initial_x,
            uint16_t initial_y,
            uint16_t initial_z
        ) noexcept:
            id{initial_id},
            valid{initial_valid},
            x{initial_x},
            y{initial_y},
            z{initial_z} {
            assert(registered);
        }

        /// Packs the fields, in declaration order, into `Size()` bytes.
        constexpr std::array<char, Size()> encode() const {
            const auto payload = pack(id, valid, x, y, z);
            return payload;
        }
    };

    /// Message with type code 1. Its fields are encoded in the order
    /// projectile_speed, turret_height, max_speed, acceleration.
    class Config: public Message<Config, 1, float, float, uint16_t, uint16_t> {
    public:
        const float    projectile_speed; ///< `projectile_speed` field of the message.
        const float    turret_height;    ///< `turret_height` field of the message.
        const uint16_t max_speed;        ///< `max_speed` field of the message.
        const uint16_t acceleration;     ///< `acceleration` field of the message.

    public:
        /// Stores the given field values; asserts that the message type is registered.
        constexpr inline Config(
            float    initial_projectile_speed,
            float    initial_turret_height,
            uint16_t initial_max_speed,
            uint16_t initial_acceleration
        ) noexcept:
            projectile_speed{initial_projectile_speed},
            turret_height{initial_turret_height},
            max_speed{initial_max_speed},
            acceleration{initial_acceleration} {
            assert(registered);
        }

        /// Packs the fields, in declaration order, into `Size()` bytes.
        constexpr std::array<char, Size()> encode() const {
            const auto payload = pack(projectile_speed, turret_height, max_speed, acceleration);
            return payload;
        }
    };

    /// Message with type code 2, carrying a single `TargetSource` value.
    class SetTargetSourceMessage: public Message<SetTargetSourceMessage, 2, TargetSource> {
    public:
        const TargetSource source; ///< `source` field of the message.

    public:
        /// Stores the given value; asserts that the message type is registered.
        constexpr inline SetTargetSourceMessage(TargetSource initial_source) noexcept:
            source(initial_source) {
            assert(registered);
        }

        /// Packs the single field into `Size()` bytes.
        constexpr std::array<char, Size()> encode() const {
            const auto payload = pack(source);
            return payload;
        }
    };

    /// Message with type code 3. Its fields are encoded in the order x, y, z.
    class StaticTargetMessage: public Message<StaticTargetMessage, 3, uint16_t, uint16_t, uint16_t> {
    public:
        const uint16_t x; ///< `x` field of the message.
        const uint16_t y; ///< `y` field of the message.
        const uint16_t z; ///< `z` field of the message.

    public:
        /// Stores the given field values; asserts that the message type is registered.
        constexpr inline StaticTargetMessage(uint16_t initial_x, uint16_t initial_y, uint16_t initial_z) noexcept:
            x{initial_x},
            y{initial_y},
            z{initial_z} {
            assert(registered);
        }

        /// Packs the fields, in declaration order, into `Size()` bytes.
        constexpr std::array<char, Size()> encode() const {
            const auto payload = pack(x, y, z);
            return payload;
        }
    };

    /// Message with type code 4, carrying a single `TurretStrategy` value.
    class SetStrategyMessage: public Message<SetStrategyMessage, 4, TurretStrategy> {
    public:
        const TurretStrategy strategy; ///< `strategy` field of the message.

    public:
        /// Stores the given value; asserts that the message type is registered.
        constexpr inline SetStrategyMessage(TurretStrategy initial_strategy) noexcept:
            strategy(initial_strategy) {
            assert(registered);
        }

        /// Packs the single field into `Size()` bytes.
        constexpr std::array<char, Size()> encode() const {
            const auto payload = pack(strategy);
            return payload;
        }
    };

    /// Message with type code 5, carrying a single `TurretStance` value.
    class SetStanceMessage: public Message<SetStanceMessage, 5, TurretStance> {
    public:
        const TurretStance stance; ///< `stance` field of the message.

    public:
        /// Stores the given value; asserts that the message type is registered.
        constexpr inline SetStanceMessage(TurretStance initial_stance) noexcept:
            stance(initial_stance) {
            assert(registered);
        }

        /// Packs the single field into `Size()` bytes.
        constexpr std::array<char, Size()> encode() const {
            const auto payload = pack(stance);
            return payload;
        }
    };

    /// Satisfied by stream-like types offering the three operations the
    /// serializer classes use: `good()`, `readsome(buffer, count)` and
    /// `write(buffer, count)`.
    template <typename T>
    concept IOAble = requires(T stream, char* readBuffer, const char* writeBuffer, size_t length) {
        { stream.good() } -> std::convertible_to<bool>;
        { stream.readsome(readBuffer, length) } -> std::convertible_to<size_t>;
        stream.write(writeBuffer, length);
    };

    /// Writes encoded messages to a stream.
    ///
    /// Only a reference to the stream is kept, so the stream must outlive the
    /// serializer.
    template <typename T>
        requires IOAble<T>
    class Serializer {
    private:
        T& output;

    public:
        /// Binds the serializer to `outputStream`.
        Serializer(T& outputStream): output{outputStream} {}

        /// Encodes `message` and writes all of its bytes to the stream.
        template <typename M, uint8_t U, typename... Fs>
        void Write(const Message<M, U, Fs...>& message) {
            const auto frame = message.ToBinary();
            output.write(frame.data(), frame.size());
        }
    };

    /// Incremental parser that extracts framed messages from a stream.
    ///
    /// Bytes are read into a fixed 128-byte buffer. The parser looks for the
    /// header token, then for the footer token, and hands the bytes from the
    /// header through the footer to MessageMaker. Parser state is kept between
    /// calls, so a frame may arrive across several ParseStream calls.
    ///
    /// NOTE(review): `offset` and `end_offset` point into this object's own
    /// buffer, so a copy-constructed Deserializer would still point into the
    /// original's buffer. Instances are presumably not meant to be copied.
    template <typename T>
        requires IOAble<T>
    class Deserializer {
    protected:
        T&                    input;
        std::array<char, 128> buf;
        // data() yields a real pointer on every standard library, whereas
        // begin() is only a pointer on some implementations.
        char* offset = buf.data();     ///< Start of the bytes still to be searched.
        char* end_offset = buf.data(); ///< One past the last buffered byte.

        // The tokens are produced with pack() so they have the same byte order
        // as the markers that Message::ToBinary writes, also on big-endian hosts.
        const std::array<char, 2> header_bytes = pack(magicHead);
        const std::array<char, 2> footer_bytes = pack(magicFoot);
        std::size_t               read_size = header_bytes.size(); ///< Bytes requested by the next read.
        std::array<char, 2>       token = header_bytes;            ///< Token currently searched for.
        ParseMode                 state = START;                   ///< Current parser state.
        ParseMode                 success = PRE_END;               ///< State entered when `token` is found.
        ParseMode                 fail = START;                    ///< State entered when `token` is not found.

        /// Searches `buffer` for the byte sequence `value`.
        ///
        /// @return Tuple of (next state, next read size, position):
        ///   - found: `success_state` and the position of the match;
        ///   - not found: `fail_state` and a position `value.size() - 1` bytes
        ///     before the end, so a token split across two reads is still seen;
        ///   - `buffer` shorter than `value`: `fail_state` and the buffer start.
        ///   The read size is always `value.size()`.
        constexpr auto findToken(
            const std::span<char>& buffer,
            const std::span<char>& value,
            const ParseMode&       fail_state,
            const ParseMode&       success_state
        ) {
            auto next_state = fail_state;
            auto next_size = value.size();
            auto next_offset = buffer.begin();

            if (buffer.size() >= value.size()) {
                const auto index = std::search(buffer.begin(), buffer.end(), value.begin(), value.end());
                if (index != buffer.end()) {
                    next_offset = index;
                    next_state = success_state;
                } else if (!value.empty()) {
                    next_offset =
                        buffer.end() - static_cast<typename std::span<char>::difference_type>(value.size() - 1);
                }
            }

            return std::tuple{next_state, next_size, next_offset};
        }

    private:
        /// Selects the token to search for and the states that follow the search.
        void expectToken(const std::array<char, 2>& wanted, ParseMode on_success, ParseMode on_fail) {
            token = wanted;
            success = on_success;
            fail = on_fail;
        }

        /// Moves the unsearched bytes to the start of the buffer.
        void compactBuffer() {
            if (offset > buf.data()) {
                const auto shift = offset - buf.data();
                std::copy(offset, end_offset, buf.data());
                end_offset -= shift;
                offset = buf.data();
            }
        }

        /// Runs findToken over the unsearched bytes and stores its results.
        void scanForToken() {
            const std::span<char> window(offset, end_offset);
            const auto [next_state, next_size, position] = findToken(window, std::span<char>(token), fail, success);
            state = next_state;
            read_size = next_size;
            // Convert the iterator back to a pointer without relying on
            // implementation-specific iterator members.
            offset = window.data() + (position - window.begin());
        }

        /// Reads up to `read_size` bytes behind the buffered data, never past
        /// the end of the buffer.
        ///
        /// When the buffer is full and no complete frame is pending, the
        /// buffered bytes cannot belong to a message that fits the buffer; they
        /// are dropped and the parser starts looking for a header again.
        ///
        /// @return Number of bytes read.
        std::size_t fillBuffer() {
            char* const buf_end = buf.data() + buf.size();
            if (end_offset == buf_end && state != EMIT) {
                offset = buf.data();
                end_offset = buf.data();
                state = START;
                expectToken(header_bytes, PRE_END, START);
            }

            const auto room = static_cast<std::size_t>(buf_end - end_offset);
            const auto wanted = std::min(read_size, room);
            const auto got =
                static_cast<std::size_t>(input.readsome(end_offset, static_cast<std::streamsize>(wanted)));
            end_offset += got;
            return got;
        }

    public:
        /// Binds the deserializer to `readStream`; only a reference is kept.
        Deserializer(T& readStream): input(readStream) {}

        /// Reads the bytes currently available from the stream and invokes
        /// `callback` once for every complete frame found.
        ///
        /// @param callback Receives the pointer returned by MessageMaker::Create
        ///                 for the frame. The pointer is null when no factory is
        ///                 registered for the frame's type code or the factory
        ///                 returned null.
        template <std::derived_from<BasePacket> Type>
        void ParseStream(std::function<void(std::unique_ptr<Type>&)> callback) {
            std::size_t read = fillBuffer();
            scanForToken();

            while (input.good() && (read > 0 || state == EMIT)) {
                switch (state) {
                case START:
                    expectToken(header_bytes, PRE_END, START);
                    compactBuffer();
                    break;
                case PRE_END:
                    // The header was found at `offset`; compacting puts it at
                    // the start of the buffer, where EMIT expects the frame.
                    expectToken(footer_bytes, EMIT, END);
                    compactBuffer();
                    break;
                case END:
                    expectToken(footer_bytes, EMIT, END);
                    break;
                case EMIT: {
                    // The type code follows the two header bytes.
                    const auto  typeCode = static_cast<uint8_t>(buf[2]);
                    char* const message_end = offset + footer_bytes.size();
                    auto        found = MessageMaker::Create(typeCode, std::span<char>(buf.data(), message_end));

                    callback(found);

                    // Keep the bytes that follow the frame and restart the
                    // header search on them.
                    const auto remaining = end_offset - message_end;
                    std::copy(message_end, end_offset, buf.data());
                    offset = buf.data();
                    end_offset = buf.data() + remaining;
                    state = START;
                    expectToken(header_bytes, PRE_END, START);
                    continue;
                }
                }

                scanForToken();
                read = fillBuffer();
            }
        }
    };

    /// Combines a Serializer and a Deserializer that operate on the same stream.
    template <typename T>
        requires IOAble<T>
    class StreamHandler {
    private:
        Serializer<T>   serializer;
        Deserializer<T> deserializer;

    public:
        /// Binds both the serializer and the deserializer to `stream`.
        StreamHandler(T& stream): serializer(stream), deserializer(stream) {}

        /// Forwards `message` to the serializer.
        template <typename M, uint8_t U, typename... Fs>
        void Write(const Message<M, U, Fs...>& message) {
            serializer.template Write<M, U, Fs...>(message);
        }

        /// Forwards `callback` to the deserializer's ParseStream.
        template <std::derived_from<BasePacket> Type>
        void ParseStream(std::function<void(std::unique_ptr<Type>&)> callback) {
            deserializer.template ParseStream<Type>(callback);
        }
    };
} // namespace cerializer