//===- llvm/ExecutionEngine/Orc/RawByteChannel.h ----------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #ifndef LLVM_EXECUTIONENGINE_ORC_RAWBYTECHANNEL_H #define LLVM_EXECUTIONENGINE_ORC_RAWBYTECHANNEL_H #include "llvm/ADT/StringRef.h" #include "llvm/ExecutionEngine/Orc/RPCSerialization.h" #include "llvm/Support/Endian.h" #include "llvm/Support/Error.h" #include <cstdint> #include <mutex> #include <string> #include <type_traits> namespace llvm { namespace orc { namespace rpc { /// Interface for byte-streams to be used with RPC. class RawByteChannel { public: virtual ~RawByteChannel() = default; /// Read Size bytes from the stream into *Dst. virtual Error readBytes(char *Dst, unsigned Size) = 0; /// Read size bytes from *Src and append them to the stream. virtual Error appendBytes(const char *Src, unsigned Size) = 0; /// Flush the stream if possible. virtual Error send() = 0; /// Notify the channel that we're starting a message send. /// Locks the channel for writing. template <typename FunctionIdT, typename SequenceIdT> Error startSendMessage(const FunctionIdT &FnId, const SequenceIdT &SeqNo) { writeLock.lock(); if (auto Err = serializeSeq(*this, FnId, SeqNo)) { writeLock.unlock(); return Err; } return Error::success(); } /// Notify the channel that we're ending a message send. /// Unlocks the channel for writing. Error endSendMessage() { writeLock.unlock(); return Error::success(); } /// Notify the channel that we're starting a message receive. /// Locks the channel for reading. template <typename FunctionIdT, typename SequenceNumberT> Error startReceiveMessage(FunctionIdT &FnId, SequenceNumberT &SeqNo) { readLock.lock(); if (auto Err = deserializeSeq(*this, FnId, SeqNo)) { readLock.unlock(); return Err; } return Error::success(); } /// Notify the channel that we're ending a message receive. /// Unlocks the channel for reading. Error endReceiveMessage() { readLock.unlock(); return Error::success(); } /// Get the lock for stream reading. std::mutex &getReadLock() { return readLock; } /// Get the lock for stream writing. std::mutex &getWriteLock() { return writeLock; } private: std::mutex readLock, writeLock; }; template <typename ChannelT, typename T> class SerializationTraits< ChannelT, T, T, typename std::enable_if< std::is_base_of<RawByteChannel, ChannelT>::value && (std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value || std::is_same<T, uint16_t>::value || std::is_same<T, int16_t>::value || std::is_same<T, uint32_t>::value || std::is_same<T, int32_t>::value || std::is_same<T, uint64_t>::value || std::is_same<T, int64_t>::value || std::is_same<T, char>::value)>::type> { public: static Error serialize(ChannelT &C, T V) { support::endian::byte_swap<T, support::big>(V); return C.appendBytes(reinterpret_cast<const char *>(&V), sizeof(T)); }; static Error deserialize(ChannelT &C, T &V) { if (auto Err = C.readBytes(reinterpret_cast<char *>(&V), sizeof(T))) return Err; support::endian::byte_swap<T, support::big>(V); return Error::success(); }; }; template <typename ChannelT> class SerializationTraits<ChannelT, bool, bool, typename std::enable_if<std::is_base_of< RawByteChannel, ChannelT>::value>::type> { public: static Error serialize(ChannelT &C, bool V) { uint8_t Tmp = V ? 1 : 0; if (auto Err = C.appendBytes(reinterpret_cast<const char *>(&Tmp), 1)) return Err; return Error::success(); } static Error deserialize(ChannelT &C, bool &V) { uint8_t Tmp = 0; if (auto Err = C.readBytes(reinterpret_cast<char *>(&Tmp), 1)) return Err; V = Tmp != 0; return Error::success(); } }; template <typename ChannelT> class SerializationTraits<ChannelT, std::string, StringRef, typename std::enable_if<std::is_base_of< RawByteChannel, ChannelT>::value>::type> { public: /// RPC channel serialization for std::strings. static Error serialize(RawByteChannel &C, StringRef S) { if (auto Err = serializeSeq(C, static_cast<uint64_t>(S.size()))) return Err; return C.appendBytes((const char *)S.data(), S.size()); } }; template <typename ChannelT, typename T> class SerializationTraits<ChannelT, std::string, T, typename std::enable_if< std::is_base_of<RawByteChannel, ChannelT>::value && (std::is_same<T, const char*>::value || std::is_same<T, char*>::value)>::type> { public: static Error serialize(RawByteChannel &C, const char *S) { return SerializationTraits<ChannelT, std::string, StringRef>::serialize(C, S); } }; template <typename ChannelT> class SerializationTraits<ChannelT, std::string, std::string, typename std::enable_if<std::is_base_of< RawByteChannel, ChannelT>::value>::type> { public: /// RPC channel serialization for std::strings. static Error serialize(RawByteChannel &C, const std::string &S) { return SerializationTraits<ChannelT, std::string, StringRef>::serialize(C, S); } /// RPC channel deserialization for std::strings. static Error deserialize(RawByteChannel &C, std::string &S) { uint64_t Count = 0; if (auto Err = deserializeSeq(C, Count)) return Err; S.resize(Count); return C.readBytes(&S[0], Count); } }; } // end namespace rpc } // end namespace orc } // end namespace llvm #endif // LLVM_EXECUTIONENGINE_ORC_RAWBYTECHANNEL_H