diff options
author | Kae <80987908+Novaenia@users.noreply.github.com> | 2023-06-20 14:33:09 +1000 |
---|---|---|
committer | Kae <80987908+Novaenia@users.noreply.github.com> | 2023-06-20 14:33:09 +1000 |
commit | 6352e8e3196f78388b6c771073f9e03eaa612673 (patch) | |
tree | e23772f79a7fbc41bc9108951e9e136857484bf4 /source/core | |
parent | 6741a057e5639280d85d0f88ba26f000baa58f61 (diff) |
everything everywhere
all at once
Diffstat (limited to 'source/core')
181 files changed, 50127 insertions, 0 deletions
diff --git a/source/core/CMakeLists.txt b/source/core/CMakeLists.txt new file mode 100644 index 0000000..d0a53b1 --- /dev/null +++ b/source/core/CMakeLists.txt @@ -0,0 +1,201 @@ +INCLUDE_DIRECTORIES ( + ${STAR_EXTERN_INCLUDES} + ${STAR_CORE_INCLUDES} + ) + +SET (star_core_HEADERS + StarAStar.hpp + StarAlgorithm.hpp + StarArray.hpp + StarAtomicSharedPtr.hpp + StarAudio.hpp + StarBTree.hpp + StarBTreeDatabase.hpp + StarBiMap.hpp + StarBlockAllocator.hpp + StarBuffer.hpp + StarByteArray.hpp + StarBytes.hpp + StarCasting.hpp + StarColor.hpp + StarCompression.hpp + StarConfig.hpp + StarDataStream.hpp + StarDataStreamDevices.hpp + StarDataStreamExtra.hpp + StarDynamicLib.hpp + StarEither.hpp + StarEncode.hpp + StarException.hpp + StarFile.hpp + StarFlatHashMap.hpp + StarFlatHashSet.hpp + StarFont.hpp + StarFormat.hpp + StarHash.hpp + StarHostAddress.hpp + StarIODevice.hpp + StarIdMap.hpp + StarImage.hpp + StarImageProcessing.hpp + StarInterpolation.hpp + StarRefPtr.hpp + StarIterator.hpp + StarJson.hpp + StarJsonBuilder.hpp + StarJsonExtra.hpp + StarJsonParser.hpp + StarJsonPath.hpp + StarJsonPatch.hpp + StarJsonRpc.hpp + StarFormattedJson.hpp + StarLexicalCast.hpp + StarLine.hpp + StarList.hpp + StarListener.hpp + StarLockFile.hpp + StarLogging.hpp + StarLruCache.hpp + StarLua.hpp + StarLuaConverters.hpp + StarMap.hpp + StarMathCommon.hpp + StarMatrix3.hpp + StarMaybe.hpp + StarMemory.hpp + StarMultiArray.hpp + StarMultiArrayInterpolator.hpp + StarMultiTable.hpp + StarNetElement.hpp + StarNetElementBasicFields.hpp + StarNetElementContainers.hpp + StarNetElementDynamicGroup.hpp + StarNetElementFloatFields.hpp + StarNetElementGroup.hpp + StarNetElementSignal.hpp + StarNetElementSyncGroup.hpp + StarNetElementSystem.hpp + StarNetElementTop.hpp + StarNetImpl.hpp + StarObserverStream.hpp + StarOptionParser.hpp + StarOrderedMap.hpp + StarOrderedSet.hpp + StarParametricFunction.hpp + StarPeriodic.hpp + StarPeriodicFunction.hpp + StarPerlin.hpp + StarPoly.hpp + StarPythonic.hpp + StarRandom.hpp + StarRandomPoint.hpp + StarRect.hpp + StarRpcPromise.hpp + StarSectorArray2D.hpp + StarSecureRandom.hpp + StarSet.hpp + StarSha256.hpp + StarShellParser.hpp + StarSignalHandler.hpp + StarSocket.hpp + StarSpatialHash2D.hpp + StarSpline.hpp + StarStaticRandom.hpp + StarStaticVector.hpp + StarString.hpp + StarStrongTypedef.hpp + StarTcp.hpp + StarThread.hpp + StarTickRateMonitor.hpp + StarTime.hpp + StarTtlCache.hpp + StarUdp.hpp + StarUnicode.hpp + StarUuid.hpp + StarVector.hpp + StarVlqEncoding.hpp + StarWeightedPool.hpp + StarWorkerPool.hpp + StarXXHash.hpp + ) + +SET (star_core_SOURCES + StarAudio.cpp + StarBTreeDatabase.cpp + StarBuffer.cpp + StarByteArray.cpp + StarColor.cpp + StarCompression.cpp + StarDataStream.cpp + StarDataStreamDevices.cpp + StarEncode.cpp + StarFile.cpp + StarFont.cpp + StarHostAddress.cpp + StarIODevice.cpp + StarImage.cpp + StarImageProcessing.cpp + StarJson.cpp + StarJsonBuilder.cpp + StarJsonExtra.cpp + StarJsonPath.cpp + StarJsonPatch.cpp + StarJsonRpc.cpp + StarFormattedJson.cpp + StarListener.cpp + StarLogging.cpp + StarLua.cpp + StarLuaConverters.cpp + StarMemory.cpp + StarNetElement.cpp + StarNetElementBasicFields.cpp + StarNetElementGroup.cpp + StarNetElementSyncGroup.cpp + StarOptionParser.cpp + StarPerlin.cpp + StarRandom.cpp + StarSha256.cpp + StarShellParser.cpp + StarSocket.cpp + StarString.cpp + StarTcp.cpp + StarThread.cpp + StarTime.cpp + StarTickRateMonitor.cpp + StarUdp.cpp + StarUnicode.cpp + StarUuid.cpp + StarWorkerPool.cpp + ) + +IF (STAR_SYSTEM_FAMILY_UNIX) + SET (star_core_SOURCES ${star_core_SOURCES} + StarDynamicLib_unix.cpp + StarException_unix.cpp + StarFile_unix.cpp + StarLockFile_unix.cpp + StarSecureRandom_unix.cpp + StarSignalHandler_unix.cpp + StarThread_unix.cpp + StarTime_unix.cpp + ) +ELSEIF (STAR_SYSTEM_FAMILY_WINDOWS) + SET (star_core_HEADERS ${star_core_HEADERS} + StarString_windows.hpp + ) + + SET (star_core_SOURCES ${star_core_SOURCES} + StarDynamicLib_windows.cpp + StarFile_windows.cpp + StarLockFile_windows.cpp + StarSignalHandler_windows.cpp + StarString_windows.cpp + StarThread_windows.cpp + StarTime_windows.cpp + StarException_windows.cpp + StarSecureRandom_windows.cpp + ) + +ENDIF () + +ADD_LIBRARY (star_core OBJECT ${star_core_SOURCES} ${star_core_HEADERS}) diff --git a/source/core/StarAStar.hpp b/source/core/StarAStar.hpp new file mode 100644 index 0000000..52dc5d4 --- /dev/null +++ b/source/core/StarAStar.hpp @@ -0,0 +1,276 @@ +#ifndef STAR_A_STAR_HPP +#define STAR_A_STAR_HPP + +#include <queue> + +#include "StarList.hpp" +#include "StarMap.hpp" +#include "StarSet.hpp" +#include "StarLexicalCast.hpp" +#include "StarMathCommon.hpp" +#include "StarBlockAllocator.hpp" + +namespace Star { +namespace AStar { + + struct Score { + Score(); + + double gScore; + double hScore; + double fScore; + }; + + // 'Edge' should be implemented as a class with public fields compatible with + // these: + // double cost; + // Node source; + // Node target; + + template <class Edge> + using Path = List<Edge>; + + template <class Edge, class Node> + class Search { + public: + typedef function<double(Node, Node)> HeuristicFunction; + typedef function<void(Node, List<Edge>& neighbors)> NeighborFunction; + typedef function<bool(Node)> GoalFunction; + typedef function<bool(Node, Node)> CompareFunction; + typedef function<bool(Edge)> ValidateEndFunction; + + Search(HeuristicFunction heuristicCost, + NeighborFunction getAdjacent, + GoalFunction goalReached, + bool returnBestIfFailed = false, + // In returnBestIfFailed mode, validateEnd checks the end of the path + // is valid, e.g. not floating in the air. + Maybe<ValidateEndFunction> validateEnd = {}, + Maybe<double> maxFScore = {}, + Maybe<unsigned> maxNodesToSearch = {}); + + // Start a new exploration, resets result if it was found before. + void start(Node startNode, Node goalNode); + // Explore the given number of nodes in the search space. If + // maxNodesToSearch is reached, or the search space is exhausted, will + // return + // false to signal failure. On success, will return true. If the given + // maxExploreNodes is exhausted before success or failure, will return + // nothing. + Maybe<bool> explore(Maybe<unsigned> maxExploreNodes = {}); + // Returns the result if it was found. + Maybe<Path<Edge>> const& result() const; + + // Convenience, equivalent to calling start, then explore({}) and returns + // result() + Maybe<Path<Edge>> const& findPath(Node startNode, Node goalNode); + + private: + struct ScoredNode { + bool operator<(ScoredNode const& other) const { + return score.fScore > other.score.fScore; + } + + Score score; + Node node; + }; + + struct NodeMeta { + Score score; + Maybe<Edge> cameFrom; + }; + + Path<Edge> reconstructPath(Node currentNode); + + HeuristicFunction m_heuristicCost; + NeighborFunction m_getAdjacent; + GoalFunction m_goalReached; + bool m_returnBestIfFailed; + Maybe<ValidateEndFunction> m_validateEnd; + Maybe<double> m_maxFScore; + Maybe<unsigned> m_maxNodesToSearch; + + Node m_goal; + Map<Node, NodeMeta, std::less<Node>, BlockAllocator<pair<Node const, NodeMeta>, 1024>> m_nodeMeta; + std::priority_queue<ScoredNode> m_openQueue; + Set<Node, std::less<Node>, BlockAllocator<Node, 1024>> m_openSet; + Set<Node, std::less<Node>, BlockAllocator<Node, 1024>> m_closedSet; + Maybe<ScoredNode> m_earlyExploration; + + bool m_finished; + Maybe<Path<Edge>> m_result; + }; + + inline Score::Score() : gScore(highest<double>()), hScore(0), fScore(highest<double>()) {} + + template <class Edge, class Node> + Search<Edge, Node>::Search(HeuristicFunction heuristicCost, + NeighborFunction getAdjacent, + GoalFunction goalReached, + bool returnBestIfFailed, + Maybe<ValidateEndFunction> validateEnd, + Maybe<double> maxFScore, + Maybe<unsigned> maxNodesToSearch) + : m_heuristicCost(heuristicCost), + m_getAdjacent(getAdjacent), + m_goalReached(goalReached), + m_returnBestIfFailed(returnBestIfFailed), + m_validateEnd(validateEnd), + m_maxFScore(maxFScore), + m_maxNodesToSearch(maxNodesToSearch) {} + + template <class Edge, class Node> + void Search<Edge, Node>::start(Node startNode, Node goalNode) { + m_goal = move(goalNode); + m_nodeMeta.clear(); + m_openQueue = std::priority_queue<ScoredNode>(); + m_openSet.clear(); + m_closedSet.clear(); + m_earlyExploration = {}; + m_finished = false; + m_result.reset(); + + Score startScore; + startScore.gScore = 0; + startScore.hScore = m_heuristicCost(startNode, m_goal); + startScore.fScore = startScore.hScore; + m_nodeMeta[startNode].score = startScore; + + m_openSet.insert(startNode); + m_openQueue.push(ScoredNode{startScore, move(startNode)}); + } + + template <class Edge, class Node> + Maybe<bool> Search<Edge, Node>::explore(Maybe<unsigned> maxExploreNodes) { + if (m_finished) + return m_result.isValid(); + + List<Edge> neighbors; + while (true) { + if ((m_maxNodesToSearch && m_closedSet.size() > *m_maxNodesToSearch) + || (m_openQueue.empty() && !m_earlyExploration)) { + m_finished = true; + // Search failed. Either return the path to the closest node to the + // target, + // or return nothing. + if (m_returnBestIfFailed) { + double bestScore = highest<double>(); + Maybe<Node> bestNode; + for (Node node : m_closedSet) { + NodeMeta const& nodeMeta = m_nodeMeta[node]; + if (m_validateEnd && nodeMeta.cameFrom && !(*m_validateEnd)(*nodeMeta.cameFrom)) + continue; + if (nodeMeta.score.hScore < bestScore) { + bestScore = nodeMeta.score.hScore; + bestNode = node; + } + } + + if (bestNode) + m_result = reconstructPath(*bestNode); + } + + return false; + } + + if (maxExploreNodes) { + if (*maxExploreNodes == 0) + return {}; + --*maxExploreNodes; + } + + ScoredNode currentScoredNode; + if (m_earlyExploration) { + currentScoredNode = m_earlyExploration.take(); + } else { + currentScoredNode = m_openQueue.top(); + m_openQueue.pop(); + if (!m_openSet.remove(currentScoredNode.node)) + // Duplicate entry in the queue due to this node's score being + // updated. + // Just ignore this node; we've already searched it. + continue; + } + + Node const& current = currentScoredNode.node; + Score const& currentScore = currentScoredNode.score; + + if (m_goalReached(current)) { + m_finished = true; + m_result = reconstructPath(current); + return true; + } + + m_closedSet.insert(current); + + neighbors.clear(); + m_getAdjacent(current, neighbors); + + for (Edge const& edge : neighbors) { + if (m_closedSet.find(edge.target) != m_closedSet.end()) + // We've already visited this node. + continue; + + double newGScore = currentScore.gScore + edge.cost; + NodeMeta& targetMeta = m_nodeMeta[edge.target]; + Score& targetScore = targetMeta.score; + if (m_openSet.find(edge.target) == m_openSet.end() || newGScore < targetScore.gScore) { + targetMeta.cameFrom = edge; + targetScore.gScore = newGScore; + targetScore.hScore = m_heuristicCost(edge.target, m_goal); + targetScore.fScore = targetScore.gScore + targetScore.hScore; + + if (m_maxFScore && targetScore.fScore > *m_maxFScore) + continue; + + // Early exploration optimization - no need to add things to the + // openQueue/openSet + // if they're at least as good as the current node. + if (targetScore.fScore <= currentScore.fScore) { + if (m_earlyExploration.isNothing()) { + m_earlyExploration = ScoredNode{targetScore, edge.target}; + continue; + } else if (m_earlyExploration->score.fScore > targetScore.fScore) { + m_openSet.insert(m_earlyExploration->node); + m_openQueue.push(*m_earlyExploration); + m_earlyExploration = ScoredNode{targetScore, edge.target}; + continue; + } + } + m_openSet.insert(edge.target); + m_openQueue.push(ScoredNode{targetScore, edge.target}); + } + } + } + } + + template <class Edge, class Node> + Maybe<Path<Edge>> const& Search<Edge, Node>::result() const { + return m_result; + } + + template <class Edge, class Node> + Maybe<Path<Edge>> const& Search<Edge, Node>::findPath(Node startNode, Node goalNode) { + start(move(startNode), move(goalNode)); + explore(); + return result(); + } + + template <class Edge, class Node> + Path<Edge> Search<Edge, Node>::reconstructPath(Node currentNode) { + Path<Edge> res; // this will be backwards, we reverse it before returning it. + while (m_nodeMeta.find(currentNode) != m_nodeMeta.end()) { + Maybe<Edge> currentEdge = m_nodeMeta[currentNode].cameFrom; + if (currentEdge.isNothing()) + break; + res.append(*currentEdge); + currentNode = currentEdge->source; + } + std::reverse(res.begin(), res.end()); + return res; + } +} + +} + +#endif diff --git a/source/core/StarAlgorithm.hpp b/source/core/StarAlgorithm.hpp new file mode 100644 index 0000000..331f653 --- /dev/null +++ b/source/core/StarAlgorithm.hpp @@ -0,0 +1,667 @@ +#ifndef STAR_ALGORITHM_HPP +#define STAR_ALGORITHM_HPP + +#include <type_traits> +#include <vector> +#include <iterator> + +#include "StarException.hpp" + +namespace Star { + +// Function that does nothing and takes any number of arguments +template <typename... T> +void nothing(T&&...) {} + +// Functional constructor call / casting. +template <typename ToType> +struct construct { + template <typename... FromTypes> + ToType operator()(FromTypes&&... fromTypes) const { + return ToType(forward<FromTypes>(fromTypes)...); + } +}; + +struct identity { + template <typename U> + constexpr decltype(auto) operator()(U&& v) const { + return std::forward<U>(v); + } +}; + +template <typename Func> +struct SwallowReturn { + template <typename... T> + void operator()(T&&... args) { + func(forward<T>(args)...); + } + + Func func; +}; + +template <typename Func> +SwallowReturn<Func> swallow(Func f) { + return SwallowReturn<Func>{move(f)}; +} + +struct Empty { + bool operator==(Empty const) const { + return true; + } + + bool operator<(Empty const) const { + return false; + } +}; + +// Compose arbitrary functions +template <typename FirstFunction, typename SecondFunction> +struct FunctionComposer { + FirstFunction f1; + SecondFunction f2; + + template <typename... T> + decltype(auto) operator()(T&&... args) { + return f1(f2(forward<T>(args)...)); + } +}; + +template <typename FirstFunction, typename SecondFunction> +decltype(auto) compose(FirstFunction&& firstFunction, SecondFunction&& secondFunction) { + return FunctionComposer<FirstFunction, SecondFunction>{move(forward<FirstFunction>(firstFunction)), move(forward<SecondFunction>(secondFunction))}; +} + +template <typename FirstFunction, typename SecondFunction, typename ThirdFunction, typename... RestFunctions> +decltype(auto) compose(FirstFunction firstFunction, SecondFunction secondFunction, ThirdFunction thirdFunction, RestFunctions... restFunctions) { + return compose(forward<FirstFunction>(firstFunction), compose(forward<SecondFunction>(secondFunction), compose(forward<ThirdFunction>(thirdFunction), forward<RestFunctions>(restFunctions)...))); +} + +template <typename Container, typename Value, typename Function> +Value fold(Container const& l, Value v, Function f) { + auto i = l.begin(); + auto e = l.end(); + while (i != e) { + v = f(v, *i); + ++i; + } + return v; +} + +// Like fold, but returns default value when container is empty. +template <typename Container, typename Function> +typename Container::value_type fold1(Container const& l, Function f) { + typename Container::value_type res = {}; + typename Container::const_iterator i = l.begin(); + typename Container::const_iterator e = l.end(); + + if (i == e) + return res; + + res = *i; + ++i; + while (i != e) { + res = f(res, *i); + ++i; + } + return res; +} + +// Return intersection of sorted containers. +template <typename Container> +Container intersect(Container const& a, Container const& b) { + Container r; + std::set_intersection(a.begin(), a.end(), b.begin(), b.end(), std::inserter(r, r.end())); + return r; +} + +template <typename MapType1, typename MapType2> +bool mapMerge(MapType1& targetMap, MapType2 const& sourceMap, bool overwrite = false) { + bool noCommonKeys = true; + for (auto i = sourceMap.begin(); i != sourceMap.end(); ++i) { + auto res = targetMap.insert(*i); + if (!res.second) { + noCommonKeys = false; + if (overwrite) + res.first->second = i->second; + } + } + return noCommonKeys; +} + +template <typename MapType1, typename MapType2> +bool mapsEqual(MapType1 const& m1, MapType2 const& m2) { + if (&m1 == &m2) + return true; + + if (m1.size() != m2.size()) + return false; + + for (auto const& m1pair : m1) { + auto m2it = m2.find(m1pair.first); + if (m2it == m2.end() || !(m2it->second == m1pair.second)) + return false; + } + + return true; +} + +template <typename Container, typename Filter> +void filter(Container& container, Filter&& filter) { + auto p = std::begin(container); + while (p != std::end(container)) { + if (!filter(*p)) + p = container.erase(p); + else + ++p; + } +} + +template <typename OutContainer, typename InContainer, typename Filter> +OutContainer filtered(InContainer const& input, Filter&& filter) { + OutContainer out; + auto p = std::begin(input); + while (p != std::end(input)) { + if (filter(*p)) + out.insert(out.end(), *p); + ++p; + } + return out; +} + +template <typename Container, typename Cond> +void eraseWhere(Container& container, Cond&& cond) { + auto p = std::begin(container); + while (p != std::end(container)) { + if (cond(*p)) + p = container.erase(p); + else + ++p; + } +} + +template <typename Container, typename Compare> +void sort(Container& c, Compare comp) { + std::sort(c.begin(), c.end(), comp); +} + +template <typename Container, typename Compare> +void stableSort(Container& c, Compare comp) { + std::stable_sort(c.begin(), c.end(), comp); +} + +template <typename Container> +void sort(Container& c) { + std::sort(c.begin(), c.end(), std::less<typename Container::value_type>()); +} + +template <typename Container> +void stableSort(Container& c) { + std::stable_sort(c.begin(), c.end(), std::less<typename Container::value_type>()); +} + +template <typename Container, typename Compare> +Container sorted(Container const& c, Compare comp) { + auto c2 = c; + sort(c2, comp); + return c2; +} + +template <typename Container, typename Compare> +Container stableSorted(Container const& c, Compare comp) { + auto c2 = c; + sort(c2, comp); + return c2; +} + +template <typename Container> +Container sorted(Container const& c) { + auto c2 = c; + sort(c2); + return c2; +} + +template <typename Container> +Container stableSorted(Container const& c) { + auto c2 = c; + sort(c2); + return c2; +} + +// Sort a container by the output of a computed value. The computed value is +// only computed *once* per item in the container, which is useful both for +// when the computed value is costly, and to avoid sorting instability with +// floating point values. Container must have size() and operator[], and also +// must be constructable with Container(size_t). +template <typename Container, typename Getter> +void sortByComputedValue(Container& container, Getter&& valueGetter, bool stable = false) { + typedef typename Container::value_type ContainerValue; + typedef decltype(valueGetter(ContainerValue())) ComputedValue; + typedef std::pair<ComputedValue, size_t> ComputedPair; + + size_t containerSize = container.size(); + + if (containerSize <= 1) + return; + + std::vector<ComputedPair> work(containerSize); + for (size_t i = 0; i < containerSize; ++i) + work[i] = {valueGetter(container[i]), i}; + + auto compare = [](ComputedPair const& a, ComputedPair const& b) { return a.first < b.first; }; + + // Sort the comptued values and the associated indexes + if (stable) + stableSort(work, compare); + else + sort(work, compare); + + Container result(containerSize); + for (size_t i = 0; i < containerSize; ++i) + swap(result[i], container[work[i].second]); + + swap(container, result); +} + +template <typename Container, typename Getter> +void stableSortByComputedValue(Container& container, Getter&& valueGetter) { + return sortByComputedValue(container, forward<Getter>(valueGetter), true); +} + +template <typename Container> +void shuffle(Container& c) { + std::random_shuffle(c.begin(), c.end()); +} + +template <typename Container> +void reverse(Container& c) { + std::reverse(c.begin(), c.end()); +} + +template <typename Container> +Container reverseCopy(Container c) { + reverse(c); + return c; +} + +template <typename T> +T copy(T c) { + return c; +} + +template <typename Container> +typename Container::value_type sum(Container const& cont) { + return fold1(cont, std::plus<typename Container::value_type>()); +} + +template <typename Container> +typename Container::value_type product(Container const& cont) { + return fold1(cont, std::multiplies<typename Container::value_type>()); +} + +template <typename OutContainer, typename InContainer, typename Function> +void transformInto(OutContainer& outContainer, InContainer&& inContainer, Function&& function) { + for (auto&& elem : inContainer) { + if (std::is_rvalue_reference<InContainer&&>::value) + outContainer.insert(outContainer.end(), function(move(elem))); + else + outContainer.insert(outContainer.end(), function(elem)); + } +} + +template <typename OutContainer, typename InContainer, typename Function> +OutContainer transform(InContainer&& container, Function&& function) { + OutContainer res; + transformInto(res, forward<InContainer>(container), forward<Function>(function)); + return res; +} + +template <typename OutputContainer, typename Function, typename Container1, typename Container2> +OutputContainer zipWith(Function&& function, Container1 const& cont1, Container2 const& cont2) { + auto it1 = cont1.begin(); + auto it2 = cont2.begin(); + + OutputContainer out; + while (it1 != cont1.end() && it2 != cont2.end()) { + out.insert(out.end(), function(*it1, *it2)); + ++it1; + ++it2; + } + + return out; +} + +// Moves the given value and into an rvalue. Works whether or not the type has +// a valid move constructor or not. Always leaves the given value in its +// default constructed state. +template <typename T> +T take(T& t) { + T t2 = move(t); + t = T(); + return t2; +} + +template <typename Container1, typename Container2> +bool containersEqual(Container1 const& cont1, Container2 const& cont2) { + if (cont1.size() != cont2.size()) + return false; + else + return std::equal(cont1.begin(), cont1.end(), cont2.begin()); +} + +// Wraps a unary function to produce an output iterator +template <typename UnaryFunction> +class FunctionOutputIterator { +public: + typedef std::output_iterator_tag iterator_category; + typedef void value_type; + typedef void difference_type; + typedef void pointer; + typedef void reference; + + class OutputProxy { + public: + OutputProxy(UnaryFunction& f) + : m_function(f) {} + + template <typename T> + OutputProxy& operator=(T&& value) { + m_function(forward<T>(value)); + return *this; + } + + private: + UnaryFunction& m_function; + }; + + explicit FunctionOutputIterator(UnaryFunction f = UnaryFunction()) + : m_function(move(f)) {} + + OutputProxy operator*() { + return OutputProxy(m_function); + } + + FunctionOutputIterator& operator++() { + return *this; + } + + FunctionOutputIterator operator++(int) { + return *this; + } + +private: + UnaryFunction m_function; +}; + +template <typename UnaryFunction> +FunctionOutputIterator<UnaryFunction> makeFunctionOutputIterator(UnaryFunction f) { + return FunctionOutputIterator<UnaryFunction>(move(f)); +} + +// Wraps a nullary function to produce an input iterator +template <typename NullaryFunction> +class FunctionInputIterator { +public: + typedef std::output_iterator_tag iterator_category; + typedef void value_type; + typedef void difference_type; + typedef void pointer; + typedef void reference; + + typedef typename std::result_of<NullaryFunction()>::type FunctionReturnType; + + explicit FunctionInputIterator(NullaryFunction f = {}) + : m_function(move(f)) {} + + FunctionReturnType operator*() { + return m_function(); + } + + FunctionInputIterator& operator++() { + return *this; + } + + FunctionInputIterator operator++(int) { + return *this; + } + +private: + NullaryFunction m_function; +}; + +template <typename NullaryFunction> +FunctionInputIterator<NullaryFunction> makeFunctionInputIterator(NullaryFunction f) { + return FunctionInputIterator<NullaryFunction>(move(f)); +} + +template <typename Iterable> +struct ReverseWrapper { +private: + Iterable& m_iterable; + +public: + ReverseWrapper(Iterable& iterable) : m_iterable(iterable) {} + + decltype(auto) begin() const { + return std::rbegin(m_iterable); + } + + decltype(auto) end() const { + return std::rend(m_iterable); + } +}; + +template <typename Iterable> +ReverseWrapper<Iterable> reverseIterate(Iterable& list) { + return ReverseWrapper<Iterable>(list); +} + +template <typename Functor> +class FinallyGuard { +public: + FinallyGuard(Functor functor) : functor(move(functor)), dismiss(false) {} + + FinallyGuard(FinallyGuard&& o) : functor(move(o.functor)), dismiss(o.dismiss) { + o.cancel(); + } + + FinallyGuard& operator=(FinallyGuard&& o) { + functor = move(o.functor); + dismiss = o.dismiss; + o.cancel(); + return *this; + } + + ~FinallyGuard() { + if (!dismiss) + functor(); + } + + void cancel() { + dismiss = true; + } + +private: + Functor functor; + bool dismiss; +}; + +template <typename Functor> +FinallyGuard<typename std::decay<Functor>::type> finally(Functor&& f) { + return FinallyGuard<Functor>(forward<Functor>(f)); +} + +// Generates compile time sequences of indexes from MinIndex to MaxIndex + +template <size_t...> +struct IndexSequence {}; + +template <size_t Min, size_t N, size_t... S> +struct GenIndexSequence : GenIndexSequence<Min, N - 1, N - 1, S...> {}; + +template <size_t Min, size_t... S> +struct GenIndexSequence<Min, Min, S...> { + typedef IndexSequence<S...> type; +}; + +// Apply a tuple as individual arguments to a function + +template <typename Function, typename Tuple, size_t... Indexes> +decltype(auto) tupleUnpackFunctionIndexes(Function&& function, Tuple&& args, IndexSequence<Indexes...> const&) { + return function(get<Indexes>(forward<Tuple>(args))...); +} + +template <typename Function, typename Tuple> +decltype(auto) tupleUnpackFunction(Function&& function, Tuple&& args) { + return tupleUnpackFunctionIndexes<Function, Tuple>(forward<Function>(function), forward<Tuple>(args), + typename GenIndexSequence<0, std::tuple_size<typename std::decay<Tuple>::type>::value>::type()); +} + +// Apply a function to every element of a tuple. This will NOT happen in a +// predictable order! + +template <typename Function, typename Tuple, size_t... Indexes> +decltype(auto) tupleApplyFunctionIndexes(Function&& function, Tuple&& args, IndexSequence<Indexes...> const&) { + return make_tuple(function(get<Indexes>(forward<Tuple>(args)))...); +} + +template <typename Function, typename Tuple> +decltype(auto) tupleApplyFunction(Function&& function, Tuple&& args) { + return tupleApplyFunctionIndexes<Function, Tuple>(forward<Function>(function), forward<Tuple>(args), + typename GenIndexSequence<0, std::tuple_size<typename std::decay<Tuple>::type>::value>::type()); +} + +// Use this version if you do not care about the return value of the function +// or your function returns void. This version DOES happen in a predictable +// order, first argument first, last argument last. + +template <typename Function, typename Tuple> +void tupleCallFunctionCaller(Function&&, Tuple&&) {} + +template <typename Tuple, typename Function, typename First, typename... Rest> +void tupleCallFunctionCaller(Tuple&& t, Function&& function) { + tupleCallFunctionCaller<Tuple, Function, Rest...>(forward<Tuple>(t), forward<Function>(function)); + function(get<sizeof...(Rest)>(forward<Tuple>(t))); +} + +template <typename Tuple, typename Function, typename... T> +void tupleCallFunctionExpander(Tuple&& t, Function&& function, tuple<T...> const&) { + tupleCallFunctionCaller<Tuple, Function, T...>(forward<Tuple>(t), forward<Function>(function)); +} + +template <typename Tuple, typename Function> +void tupleCallFunction(Tuple&& t, Function&& function) { + tupleCallFunctionExpander<Tuple, Function>(forward<Tuple>(t), forward<Function>(function), forward<Tuple>(t)); +} + +// Get a subset of a tuple + +template <typename Tuple, size_t... Indexes> +decltype(auto) subTupleIndexes(Tuple&& t, IndexSequence<Indexes...> const&) { + return make_tuple(get<Indexes>(forward<Tuple>(t))...); +} + +template <size_t Min, size_t Size, typename Tuple> +decltype(auto) subTuple(Tuple&& t) { + return subTupleIndexes(forward<Tuple>(t), GenIndexSequence<Min, Size>::type()); +} + +template <size_t Trim, typename Tuple> +decltype(auto) trimTuple(Tuple&& t) { + return subTupleIndexes(forward<Tuple>(t), typename GenIndexSequence<Trim, std::tuple_size<typename std::decay<Tuple>::type>::value>::type()); +} + +// Unpack a parameter expansion into a container + +template <typename Container> +void unpackVariadicImpl(Container&) {} + +template <typename Container, typename TFirst, typename... TRest> +void unpackVariadicImpl(Container& container, TFirst&& tfirst, TRest&&... trest) { + container.insert(container.end(), forward<TFirst>(tfirst)); + unpackVariadicImpl(container, forward<TRest>(trest)...); +} + +template <typename Container, typename... T> +Container unpackVariadic(T&&... t) { + Container c; + unpackVariadicImpl(c, forward<T>(t)...); + return c; +} + +// Call a function on each entry in a variadic parameter set + +template <typename Function> +void callFunctionVariadic(Function&&) {} + +template <typename Function, typename Arg1, typename... ArgRest> +void callFunctionVariadic(Function&& function, Arg1&& arg1, ArgRest&&... argRest) { + function(arg1); + callFunctionVariadic(forward<Function>(function), forward<ArgRest>(argRest)...); +} + +template <typename... Rest> +struct VariadicTypedef; + +template <> +struct VariadicTypedef<> {}; + +template <typename FirstT, typename... RestT> +struct VariadicTypedef<FirstT, RestT...> { + typedef FirstT First; + typedef VariadicTypedef<RestT...> Rest; +}; + +// For generic types, directly use the result of the signature of its +// 'operator()' +template <typename T> +struct FunctionTraits : public FunctionTraits<decltype(&T::operator())> {}; + +template <typename ReturnType, typename... ArgsTypes> +struct FunctionTraits<ReturnType(ArgsTypes...)> { + // arity is the number of arguments. + static constexpr size_t Arity = sizeof...(ArgsTypes); + + typedef ReturnType Return; + + typedef VariadicTypedef<ArgsTypes...> Args; + typedef tuple<ArgsTypes...> ArgTuple; + + template <size_t i> + struct Arg { + // the i-th argument is equivalent to the i-th tuple element of a tuple + // composed of those arguments. + typedef typename tuple_element<i, ArgTuple>::type type; + }; +}; + +template <typename ReturnType, typename... Args> +struct FunctionTraits<ReturnType (*)(Args...)> : public FunctionTraits<ReturnType(Args...)> {}; + +template <typename FunctionType> +struct FunctionTraits<std::function<FunctionType>> : public FunctionTraits<FunctionType> {}; + +template <typename ClassType, typename ReturnType, typename... Args> +struct FunctionTraits<ReturnType (ClassType::*)(Args...)> : public FunctionTraits<ReturnType(Args...)> { + typedef ClassType& OwnerType; +}; + +template <typename ClassType, typename ReturnType, typename... Args> +struct FunctionTraits<ReturnType (ClassType::*)(Args...) const> : public FunctionTraits<ReturnType(Args...)> { + typedef const ClassType& OwnerType; +}; + +template <typename T> +struct FunctionTraits<T&> : public FunctionTraits<T> {}; + +template <typename T> +struct FunctionTraits<T const&> : public FunctionTraits<T> {}; + +template <typename T> +struct FunctionTraits<T&&> : public FunctionTraits<T> {}; + +template <typename T> +struct FunctionTraits<T const&&> : public FunctionTraits<T> {}; + +} + +#endif diff --git a/source/core/StarArray.hpp b/source/core/StarArray.hpp new file mode 100644 index 0000000..1b3e150 --- /dev/null +++ b/source/core/StarArray.hpp @@ -0,0 +1,254 @@ +#ifndef STAR_ARRAY_H +#define STAR_ARRAY_H + +#include <array> + +#include "StarHash.hpp" + +namespace Star { + +// Somewhat nicer form of std::array, always initializes values, uses nicer +// constructor pattern. +template <typename ElementT, size_t SizeN> +class Array : public std::array<ElementT, SizeN> { +public: + typedef std::array<ElementT, SizeN> Base; + + typedef ElementT Element; + static size_t const ArraySize = SizeN; + + typedef Element* iterator; + typedef Element const* const_iterator; + + typedef Element& reference; + typedef Element const& const_reference; + + typedef Element value_type; + + static Array filled(Element const& e); + + template <typename Iterator> + static Array copyFrom(Iterator p, size_t n = NPos); + + Array(); + + explicit Array(Element const& e1); + + template <typename... T> + Array(Element const& e1, T const&... rest); + + template <typename Element2> + explicit Array(Array<Element2, SizeN> const& a); + + template <size_t i> + reference get(); + + template <size_t i> + const_reference get() const; + + template <typename T2> + Array& operator=(Array<T2, SizeN> const& array); + + Element* ptr(); + Element const* ptr() const; + + bool operator==(Array const& a) const; + bool operator!=(Array const& a) const; + bool operator<(Array const& a) const; + bool operator<=(Array const& a) const; + bool operator>(Array const& a) const; + bool operator>=(Array const& a) const; + + template <size_t Size2> + Array<ElementT, Size2> toSize() const; + +private: + // Instead of {} array initialization, use recursive assignment to mimic old + // C++ style construction with less strict narrowing rules. + template <typename T, typename... TL> + void set(T const& e, TL const&... rest); + void set(); +}; + +template <typename DataT, size_t SizeT> +struct hash<Array<DataT, SizeT>> { + size_t operator()(Array<DataT, SizeT> const& a) const; + Star::hash<DataT> dataHasher; +}; + +typedef Array<int, 2> Array2I; +typedef Array<size_t, 2> Array2S; +typedef Array<unsigned, 2> Array2U; +typedef Array<float, 2> Array2F; +typedef Array<double, 2> Array2D; + +typedef Array<int, 3> Array3I; +typedef Array<size_t, 3> Array3S; +typedef Array<unsigned, 3> Array3U; +typedef Array<float, 3> Array3F; +typedef Array<double, 3> Array3D; + +typedef Array<int, 4> Array4I; +typedef Array<size_t, 4> Array4S; +typedef Array<unsigned, 4> Array4U; +typedef Array<float, 4> Array4F; +typedef Array<double, 4> Array4D; + +template <typename Element, size_t Size> +Array<Element, Size> Array<Element, Size>::filled(Element const& e) { + Array a; + a.fill(e); + return a; +} + +template <typename Element, size_t Size> +template <typename Iterator> +Array<Element, Size> Array<Element, Size>::copyFrom(Iterator p, size_t n) { + Array a; + for (size_t i = 0; i < n && i < Size; ++i) + a[i] = *(p++); + return a; +} + +template <typename Element, size_t Size> +Array<Element, Size>::Array() + : Base() {} + +template <typename Element, size_t Size> +Array<Element, Size>::Array(Element const& e1) { + static_assert(Size == 1, "Incorrect size in Array constructor"); + set(e1); +} + +template <typename Element, size_t Size> +template <typename... T> +Array<Element, Size>::Array(Element const& e1, T const&... rest) { + static_assert(sizeof...(rest) == Size - 1, "Incorrect size in Array constructor"); + set(e1, rest...); +} + +template <typename Element, size_t Size> +template <typename Element2> +Array<Element, Size>::Array(Array<Element2, Size> const& a) { + std::copy(a.begin(), a.end(), Base::begin()); +} + +template <typename Element, size_t Size> +template <size_t i> +auto Array<Element, Size>::get() -> reference { + static_assert(i < Size, "Incorrect size in Array::at"); + return Base::operator[](i); +} + +template <typename Element, size_t Size> +template <size_t i> +auto Array<Element, Size>::get() const -> const_reference { + static_assert(i < Size, "Incorrect size in Array::at"); + return Base::operator[](i); +} + +template <typename Element, size_t Size> +template <typename T2> +Array<Element, Size>& Array<Element, Size>::operator=(Array<T2, Size> const& array) { + std::copy(array.begin(), array.end(), Base::begin()); + return *this; +} + +template <typename Element, size_t Size> +Element* Array<Element, Size>::ptr() { + return Base::data(); +} + +template <typename Element, size_t Size> +Element const* Array<Element, Size>::ptr() const { + return Base::data(); +} + +template <typename Element, size_t Size> +bool Array<Element, Size>::operator==(Array const& a) const { + for (size_t i = 0; i < Size; ++i) + if ((*this)[i] != a[i]) + return false; + return true; +} + +template <typename Element, size_t Size> +bool Array<Element, Size>::operator!=(Array const& a) const { + return !operator==(a); +} + +template <typename Element, size_t Size> +bool Array<Element, Size>::operator<(Array const& a) const { + for (size_t i = 0; i < Size; ++i) { + if ((*this)[i] < a[i]) + return true; + else if (a[i] < (*this)[i]) + return false; + } + return false; +} + +template <typename Element, size_t Size> +bool Array<Element, Size>::operator<=(Array const& a) const { + for (size_t i = 0; i < Size; ++i) { + if ((*this)[i] < a[i]) + return true; + else if (a[i] < (*this)[i]) + return false; + } + return true; +} + +template <typename Element, size_t Size> +bool Array<Element, Size>::operator>(Array const& a) const { + return a < *this; +} + +template <typename Element, size_t Size> +bool Array<Element, Size>::operator>=(Array const& a) const { + return a <= *this; +} + +template <typename Element, size_t Size> +template <size_t Size2> +Array<Element, Size2> Array<Element, Size>::toSize() const { + Array<Element, Size2> r; + size_t ns = std::min(Size2, Size); + for (size_t i = 0; i < ns; ++i) + r[i] = (*this)[i]; + return r; +} + +template <typename Element, size_t Size> +void Array<Element, Size>::set() {} + +template <typename Element, size_t Size> +template <typename T, typename... TL> +void Array<Element, Size>::set(T const& e, TL const&... rest) { + Base::operator[](Size - 1 - sizeof...(rest)) = e; + set(rest...); +} + +template <typename Element, size_t Size> +std::ostream& operator<<(std::ostream& os, Array<Element, Size> const& a) { + os << '['; + for (size_t i = 0; i < Size; ++i) { + os << a[i]; + if (i != Size - 1) + os << ", "; + } + os << ']'; + return os; +} + +template <typename DataT, size_t SizeT> +size_t hash<Array<DataT, SizeT>>::operator()(Array<DataT, SizeT> const& a) const { + size_t hashval = 0; + for (size_t i = 0; i < SizeT; ++i) + hashCombine(hashval, dataHasher(a[i])); + return hashval; +} + +} + +#endif diff --git a/source/core/StarAtomicSharedPtr.hpp b/source/core/StarAtomicSharedPtr.hpp new file mode 100644 index 0000000..e85ffd5 --- /dev/null +++ b/source/core/StarAtomicSharedPtr.hpp @@ -0,0 +1,121 @@ +#ifndef STAR_ATOMIC_SHARED_PTR_HPP +#define STAR_ATOMIC_SHARED_PTR_HPP + +#include "StarThread.hpp" + +namespace Star { + +// Thread safe shared_ptr such that is is possible to safely access the +// contents of the shared_ptr while other threads might be updating it. Makes +// it possible to safely do Read Copy Update. +template <typename T> +class AtomicSharedPtr { +public: + typedef shared_ptr<T> SharedPtr; + typedef weak_ptr<T> WeakPtr; + + AtomicSharedPtr(); + AtomicSharedPtr(AtomicSharedPtr const& p); + AtomicSharedPtr(AtomicSharedPtr&& p); + AtomicSharedPtr(SharedPtr p); + + SharedPtr load() const; + WeakPtr weak() const; + void store(SharedPtr p); + void reset(); + + explicit operator bool() const; + bool unique() const; + + SharedPtr operator->() const; + + AtomicSharedPtr& operator=(AtomicSharedPtr const& p); + AtomicSharedPtr& operator=(AtomicSharedPtr&& p); + AtomicSharedPtr& operator=(SharedPtr p); + +private: + SharedPtr m_ptr; + mutable SpinLock m_lock; +}; + +template <typename T> +AtomicSharedPtr<T>::AtomicSharedPtr() {} + +template <typename T> +AtomicSharedPtr<T>::AtomicSharedPtr(AtomicSharedPtr const& p) + : m_ptr(p.load()) {} + +template <typename T> +AtomicSharedPtr<T>::AtomicSharedPtr(AtomicSharedPtr&& p) + : m_ptr(move(p.m_ptr)) {} + +template <typename T> +AtomicSharedPtr<T>::AtomicSharedPtr(SharedPtr p) + : m_ptr(move(p)) {} + +template <typename T> +auto AtomicSharedPtr<T>::load() const -> SharedPtr { + SpinLocker locker(m_lock); + return m_ptr; +} + +template <typename T> +auto AtomicSharedPtr<T>::weak() const -> WeakPtr { + SpinLocker locker(m_lock); + return WeakPtr(m_ptr); +} + +template <typename T> +void AtomicSharedPtr<T>::store(SharedPtr p) { + SpinLocker locker(m_lock); + m_ptr = move(p); +} + +template <typename T> +void AtomicSharedPtr<T>::reset() { + SpinLocker locker(m_lock); + m_ptr.reset(); +} + +template <typename T> +AtomicSharedPtr<T>::operator bool() const { + SpinLocker locker(m_lock); + return (bool)m_ptr; +} + +template <typename T> +bool AtomicSharedPtr<T>::unique() const { + SpinLocker locker(m_lock); + return m_ptr.unique(); +} + +template <typename T> +auto AtomicSharedPtr<T>::operator-> () const -> SharedPtr { + SpinLocker locker(m_lock); + return m_ptr; +} + +template <typename T> +AtomicSharedPtr<T>& AtomicSharedPtr<T>::operator=(AtomicSharedPtr const& p) { + SpinLocker locker(m_lock); + m_ptr = p.load(); + return *this; +} + +template <typename T> +AtomicSharedPtr<T>& AtomicSharedPtr<T>::operator=(AtomicSharedPtr&& p) { + SpinLocker locker(m_lock); + m_ptr = move(p.m_ptr); + return *this; +} + +template <typename T> +AtomicSharedPtr<T>& AtomicSharedPtr<T>::operator=(SharedPtr p) { + SpinLocker locker(m_lock); + m_ptr = move(p); + return *this; +} + +} + +#endif diff --git a/source/core/StarAudio.cpp b/source/core/StarAudio.cpp new file mode 100644 index 0000000..7e4cd31 --- /dev/null +++ b/source/core/StarAudio.cpp @@ -0,0 +1,562 @@ +// Fixes unused variable warning +#define OV_EXCLUDE_STATIC_CALLBACKS + +#include "vorbis/codec.h" +#include "vorbis/vorbisfile.h" + +#include "StarAudio.hpp" +#include "StarBuffer.hpp" +#include "StarFile.hpp" +#include "StarFormat.hpp" +#include "StarLogging.hpp" +#include "StarDataStreamDevices.hpp" + +namespace Star { + +namespace { + struct WaveData { + ByteArrayPtr byteArray; + unsigned channels; + unsigned sampleRate; + }; + + template <typename T> + T readLEType(IODevicePtr const& device) { + T t; + device->readFull((char*)&t, sizeof(t)); + fromByteOrder(ByteOrder::LittleEndian, (char*)&t, sizeof(t)); + return t; + } + + bool isUncompressed(IODevicePtr device) { + const size_t sigLength = 4; + unique_ptr<char[]> riffSig(new char[sigLength + 1]()); // RIFF\0 + unique_ptr<char[]> waveSig(new char[sigLength + 1]()); // WAVE\0 + + StreamOffset previousOffset = device->pos(); + device->seek(0); + device->readFull(riffSig.get(), sigLength); + device->seek(4, IOSeek::Relative); + device->readFull(waveSig.get(), sigLength); + device->seek(previousOffset); + if (strcmp(riffSig.get(), "RIFF") == 0 && strcmp(waveSig.get(), "WAVE") == 0) { // bytes are magic + return true; + } + return false; + } + + WaveData parseWav(IODevicePtr device) { + const size_t sigLength = 4; + unique_ptr<char[]> riffSig(new char[sigLength + 1]()); // RIFF\0 + unique_ptr<char[]> waveSig(new char[sigLength + 1]()); // WAVE\0 + unique_ptr<char[]> fmtSig(new char[sigLength + 1]()); // fmt \0 + unique_ptr<char[]> dataSig(new char[sigLength + 1]()); // data\0 + + // RIFF Chunk Descriptor + device->seek(0); + device->readFull(riffSig.get(), sigLength); + + uint32_t fileSize = readLEType<uint32_t>(device); + fileSize += sigLength + sizeof(fileSize); + if (fileSize != device->size()) + throw AudioException(strf("Wav file is wrong size, reports %d is actually %d", fileSize, device->size())); + + device->readFull(waveSig.get(), sigLength); + + if ((strcmp(riffSig.get(), "RIFF") != 0) || (strcmp(waveSig.get(), "WAVE") != 0)) { // bytes are not magic + auto p = [](char a) { return isprint(a) ? a : '?'; }; + throw AudioException(strf("Wav file has wrong magic bytes, got `%c%c%c%c' and `%c%c%c%c' but expected `RIFF' and `WAVE'", + p(riffSig[0]), p(riffSig[1]), p(riffSig[2]), p(riffSig[3]), p(waveSig[0]), p(waveSig[1]), p(waveSig[2]), p(waveSig[3]))); + } + + // fmt subchunk + + device->readFull(fmtSig.get(), sigLength); + if (strcmp(fmtSig.get(), "fmt ") != 0) { // friendship is magic + auto p = [](char a) { return isprint(a) ? a : '?'; }; + throw AudioException(strf("Wav file fmt subchunk has wrong magic bytes, got `%c%c%c%c' but expected `fmt '", + p(fmtSig[0]), + p(fmtSig[1]), + p(fmtSig[2]), + p(fmtSig[3]))); + } + + uint32_t fmtSubchunkSize = readLEType<uint32_t>(device); + fmtSubchunkSize += sigLength; + if (fmtSubchunkSize < 20) + throw AudioException(strf("fmt subchunk is sized wrong, expected 20 got %d. Is this wav file not PCM?", fmtSubchunkSize)); + + uint16_t audioFormat = readLEType<uint16_t>(device); + if (audioFormat != 1) + throw AudioException("audioFormat data indicates that wav file is something other than PCM format. Unsupported."); + + uint16_t wavChannels = readLEType<uint16_t>(device); + uint32_t wavSampleRate = readLEType<uint32_t>(device); + uint32_t wavByteRate = readLEType<uint32_t>(device); + uint16_t wavBlockAlign = readLEType<uint16_t>(device); + uint16_t wavBitsPerSample = readLEType<uint16_t>(device); + + if (wavBitsPerSample != 16) + throw AudioException("Only 16-bit PCM wavs are supported."); + if (wavByteRate * 8 != wavSampleRate * wavChannels * wavBitsPerSample) + throw AudioException("Sanity check failed, ByteRate is wrong"); + if (wavBlockAlign * 8 != wavChannels * wavBitsPerSample) + throw AudioException("Sanity check failed, BlockAlign is wrong"); + + device->seek(fmtSubchunkSize - 20, IOSeek::Relative); + + // data subchunk + + device->readFull(dataSig.get(), sigLength); + if (strcmp(dataSig.get(), "data") != 0) { // magic or more magic? + auto p = [](char a) { return isprint(a) ? a : '?'; }; + throw AudioException(strf("Wav file data subchunk has wrong magic bytes, got `%c%c%c%c' but expected `data'", + p(dataSig[0]), p(dataSig[1]), p(dataSig[2]), p(dataSig[3]))); + } + + uint32_t wavDataSize = readLEType<uint32_t>(device); + size_t wavDataOffset = (size_t)device->pos(); + if (wavDataSize + wavDataOffset > (size_t)device->size()) { + throw AudioException(strf("Wav file data size reported is inconsistent with file size, got %d but expected %d", + device->size(), wavDataSize + wavDataOffset)); + } + + ByteArrayPtr pcmData = make_shared<ByteArray>(); + pcmData->resize(wavDataSize); + + // Copy across data and perform and endianess conversion if needed + device->readFull(pcmData->ptr(), pcmData->size()); + for (size_t i = 0; i < pcmData->size() / 2; ++i) + fromByteOrder(ByteOrder::LittleEndian, pcmData->ptr() + i * 2, 2); + + return WaveData{move(pcmData), wavChannels, wavSampleRate}; + } +} + +class CompressedAudioImpl { +public: + static size_t readFunc(void* ptr, size_t size, size_t nmemb, void* datasource) { + return static_cast<ExternalBuffer*>(datasource)->read((char*)ptr, size * nmemb) / size; + } + + static int seekFunc(void* datasource, ogg_int64_t offset, int whence) { + static_cast<ExternalBuffer*>(datasource)->seek(offset, (IOSeek)whence); + return 0; + }; + + static long int tellFunc(void* datasource) { + return (long int)static_cast<ExternalBuffer*>(datasource)->pos(); + }; + + CompressedAudioImpl(CompressedAudioImpl const& impl) { + m_audioData = impl.m_audioData; + m_memoryFile.reset(m_audioData->ptr(), m_audioData->size()); + m_vorbisInfo = nullptr; + } + + CompressedAudioImpl(IODevicePtr audioData) { + audioData->open(IOMode::Read); + audioData->seek(0); + m_audioData = make_shared<ByteArray>(audioData->readBytes((size_t)audioData->size())); + m_memoryFile.reset(m_audioData->ptr(), m_audioData->size()); + m_vorbisInfo = nullptr; + } + + ~CompressedAudioImpl() { + ov_clear(&m_vorbisFile); + } + + bool open() { + m_callbacks.read_func = readFunc; + m_callbacks.seek_func = seekFunc; + m_callbacks.tell_func = tellFunc; + m_callbacks.close_func = NULL; + + if (ov_open_callbacks(&m_memoryFile, &m_vorbisFile, NULL, 0, m_callbacks) < 0) + return false; + + m_vorbisInfo = ov_info(&m_vorbisFile, -1); + return true; + } + + unsigned channels() { + return m_vorbisInfo->channels; + } + + unsigned sampleRate() { + return m_vorbisInfo->rate; + } + + double totalTime() { + return ov_time_total(&m_vorbisFile, -1); + } + + uint64_t totalSamples() { + return ov_pcm_total(&m_vorbisFile, -1); + } + + void seekTime(double time) { + int ret = ov_time_seek(&m_vorbisFile, time); + + if (ret != 0) + throw StarException("Cannot seek ogg stream Audio::seekTime"); + } + + void seekSample(uint64_t pos) { + int ret = ov_pcm_seek(&m_vorbisFile, pos); + + if (ret != 0) + throw StarException("Cannot seek ogg stream in Audio::seekSample"); + } + + double currentTime() { + return ov_time_tell(&m_vorbisFile); + } + + uint64_t currentSample() { + return ov_pcm_tell(&m_vorbisFile); + } + + size_t readPartial(int16_t* buffer, size_t bufferSize) { + int bitstream; + int read; + // ov_read takes int parameter, so do some magic here to make sure we don't + // overflow + bufferSize *= 2; +#if STAR_LITTLE_ENDIAN + read = ov_read(&m_vorbisFile, (char*)buffer, bufferSize, 0, 2, 1, &bitstream); +#else + read = ov_read(&m_vorbisFile, (char*)buffer, bufferSize, 1, 2, 1, &bitstream); +#endif + if (read < 0) + throw AudioException("Error in Audio::read"); + + // read in bytes, returning number of int16_t samples. + return read / 2; + } + +private: + ByteArrayConstPtr m_audioData; + ExternalBuffer m_memoryFile; + ov_callbacks m_callbacks; + OggVorbis_File m_vorbisFile; + vorbis_info* m_vorbisInfo; +}; + +class UncompressedAudioImpl { +public: + UncompressedAudioImpl(UncompressedAudioImpl const& impl) { + m_channels = impl.m_channels; + m_sampleRate = impl.m_sampleRate; + m_audioData = impl.m_audioData; + m_memoryFile.reset(m_audioData->ptr(), m_audioData->size()); + } + + UncompressedAudioImpl(CompressedAudioImpl& impl) { + m_channels = impl.channels(); + m_sampleRate = impl.sampleRate(); + + int16_t buffer[1024]; + Buffer uncompressBuffer; + while (true) { + size_t ramt = impl.readPartial(buffer, 1024); + + if (ramt == 0) { + // End of stream reached + break; + } else { + uncompressBuffer.writeFull((char*)buffer, ramt * 2); + } + } + + m_audioData = make_shared<ByteArray>(uncompressBuffer.takeData()); + m_memoryFile.reset(m_audioData->ptr(), m_audioData->size()); + } + + UncompressedAudioImpl(ByteArrayConstPtr data, unsigned channels, unsigned sampleRate) { + m_channels = channels; + m_sampleRate = sampleRate; + m_audioData = move(data); + m_memoryFile.reset(m_audioData->ptr(), m_audioData->size()); + } + + bool open() { + return true; + } + + unsigned channels() { + return m_channels; + } + + unsigned sampleRate() { + return m_sampleRate; + } + + double totalTime() { + return (double)totalSamples() / m_sampleRate; + } + + uint64_t totalSamples() { + return m_memoryFile.dataSize() / 2 / m_channels; + } + + void seekTime(double time) { + seekSample((uint64_t)(time * m_sampleRate)); + } + + void seekSample(uint64_t pos) { + m_memoryFile.seek(pos * 2 * m_channels); + } + + double currentTime() { + return (double)currentSample() / m_sampleRate; + } + + uint64_t currentSample() { + return m_memoryFile.pos() / 2 / m_channels; + } + + size_t readPartial(int16_t* buffer, size_t bufferSize) { + if (bufferSize != NPos) + bufferSize = bufferSize * 2; + return m_memoryFile.read((char*)buffer, bufferSize) / 2; + } + +private: + unsigned m_channels; + unsigned m_sampleRate; + ByteArrayConstPtr m_audioData; + ExternalBuffer m_memoryFile; +}; + +Audio::Audio(IODevicePtr device) { + if (!device->isOpen()) + device->open(IOMode::Read); + + if (isUncompressed(device)) { + WaveData data = parseWav(device); + m_uncompressed = make_shared<UncompressedAudioImpl>(move(data.byteArray), data.channels, data.sampleRate); + } else { + m_compressed = make_shared<CompressedAudioImpl>(device); + if (!m_compressed->open()) + throw AudioException("File does not appear to be a valid ogg bitstream"); + } +} + +Audio::Audio(Audio const& audio) { + *this = audio; +} + +Audio::Audio(Audio&& audio) { + operator=(move(audio)); +} + +Audio& Audio::operator=(Audio const& audio) { + if (audio.m_uncompressed) { + m_uncompressed = make_shared<UncompressedAudioImpl>(*audio.m_uncompressed); + m_uncompressed->open(); + } else { + m_compressed = make_shared<CompressedAudioImpl>(*audio.m_compressed); + m_compressed->open(); + } + + seekSample(audio.currentSample()); + return *this; +} + +Audio& Audio::operator=(Audio&& audio) { + m_compressed = move(audio.m_compressed); + m_uncompressed = move(audio.m_uncompressed); + return *this; +} + +unsigned Audio::channels() const { + if (m_uncompressed) + return m_uncompressed->channels(); + else + return m_compressed->channels(); +} + +unsigned Audio::sampleRate() const { + if (m_uncompressed) + return m_uncompressed->sampleRate(); + else + return m_compressed->sampleRate(); +} + +double Audio::totalTime() const { + if (m_uncompressed) + return m_uncompressed->totalTime(); + else + return m_compressed->totalTime(); +} + +uint64_t Audio::totalSamples() const { + if (m_uncompressed) + return m_uncompressed->totalSamples(); + else + return m_compressed->totalSamples(); +} + +bool Audio::compressed() const { + return (bool)m_compressed; +} + +void Audio::uncompress() { + if (m_compressed) { + m_uncompressed = make_shared<UncompressedAudioImpl>(*m_compressed); + m_compressed.reset(); + } +} + +void Audio::seekTime(double time) { + if (m_uncompressed) + m_uncompressed->seekTime(time); + else + m_compressed->seekTime(time); +} + +void Audio::seekSample(uint64_t pos) { + if (m_uncompressed) + m_uncompressed->seekSample(pos); + else + m_compressed->seekSample(pos); +} + +double Audio::currentTime() const { + if (m_uncompressed) + return m_uncompressed->currentTime(); + else + return m_compressed->currentTime(); +} + +uint64_t Audio::currentSample() const { + if (m_uncompressed) + return m_uncompressed->currentSample(); + else + return m_compressed->currentSample(); +} + +size_t Audio::readPartial(int16_t* buffer, size_t bufferSize) { + if (bufferSize == 0) + return 0; + + if (m_uncompressed) + return m_uncompressed->readPartial(buffer, bufferSize); + else + return m_compressed->readPartial(buffer, bufferSize); +} + +size_t Audio::read(int16_t* buffer, size_t bufferSize) { + if (bufferSize == 0) + return 0; + + size_t readTotal = 0; + while (readTotal < bufferSize) { + size_t toGo = bufferSize - readTotal; + size_t ramt = readPartial(buffer + readTotal, toGo); + readTotal += ramt; + // End of stream reached + if (ramt == 0) + break; + } + return readTotal; +} + +size_t Audio::resample(unsigned destinationChannels, unsigned destinationSampleRate, int16_t* destinationBuffer, size_t destinationBufferSize, double velocity) { + unsigned destinationSamples = destinationBufferSize / destinationChannels; + if (destinationSamples == 0) + return 0; + + unsigned sourceChannels = channels(); + unsigned sourceSampleRate = sampleRate(); + + if (velocity != 1.0) + sourceSampleRate = (unsigned)(sourceSampleRate * velocity); + + if (destinationChannels == sourceChannels && destinationSampleRate == sourceSampleRate) { + // If the destination and source channel count and sample rate are the + // same, this is the same as a read. + + return read(destinationBuffer, destinationBufferSize); + + } else if (destinationSampleRate == sourceSampleRate) { + // If the destination and source sample rate are the same, then we can skip + // the super-sampling math. + + unsigned sourceBufferSize = destinationSamples * sourceChannels; + + m_workingBuffer.resize(sourceBufferSize * sizeof(int16_t)); + int16_t* sourceBuffer = (int16_t*)m_workingBuffer.ptr(); + + unsigned readSamples = read(sourceBuffer, sourceBufferSize) / sourceChannels; + + for (unsigned sample = 0; sample < readSamples; ++sample) { + unsigned sourceBufferIndex = sample * sourceChannels; + unsigned destinationBufferIndex = sample * destinationChannels; + + for (unsigned destinationChannel = 0; destinationChannel < destinationChannels; ++destinationChannel) { + // If the destination channel count is greater than the source + // channels, simply copy the last channel + unsigned sourceChannel = min(destinationChannel, sourceChannels - 1); + destinationBuffer[destinationBufferIndex + destinationChannel] = + sourceBuffer[sourceBufferIndex + sourceChannel]; + } + } + + return readSamples * destinationChannels; + + } else { + // Otherwise, we have to do a full resample. + + unsigned sourceSamples = ((uint64_t)sourceSampleRate * destinationSamples + destinationSampleRate - 1) / destinationSampleRate; + unsigned sourceBufferSize = sourceSamples * sourceChannels; + + m_workingBuffer.resize(sourceBufferSize * sizeof(int16_t)); + int16_t* sourceBuffer = (int16_t*)m_workingBuffer.ptr(); + + unsigned readSamples = read(sourceBuffer, sourceBufferSize) / sourceChannels; + + if (readSamples == 0) + return 0; + + unsigned writtenSamples = 0; + + for (unsigned destinationSample = 0; destinationSample < destinationSamples; ++destinationSample) { + unsigned destinationBufferIndex = destinationSample * destinationChannels; + + for (unsigned destinationChannel = 0; destinationChannel < destinationChannels; ++destinationChannel) { + static int const SuperSampleFactor = 8; + + // If the destination channel count is greater than the source + // channels, simply copy the last channel + unsigned sourceChannel = min(destinationChannel, sourceChannels - 1); + + int sample = 0; + int sampleCount = 0; + for (int superSample = 0; superSample < SuperSampleFactor; ++superSample) { + unsigned sourceSample = (unsigned)((destinationSample * SuperSampleFactor + superSample) * sourceSamples / destinationSamples) / SuperSampleFactor; + if (sourceSample < readSamples) { + unsigned sourceBufferIndex = sourceSample * sourceChannels; + starAssert(sourceBufferIndex + sourceChannel < sourceBufferSize); + sample += sourceBuffer[sourceBufferIndex + sourceChannel]; + ++sampleCount; + } + } + + // If sampleCount is zero, then we are past the end of our read data + // completely, and can stop + if (sampleCount == 0) + return writtenSamples * destinationChannels; + + sample /= sampleCount; + destinationBuffer[destinationBufferIndex + destinationChannel] = (int16_t)sample; + writtenSamples = destinationSample + 1; + } + } + + return writtenSamples * destinationChannels; + } +} + +} diff --git a/source/core/StarAudio.hpp b/source/core/StarAudio.hpp new file mode 100644 index 0000000..eda28d9 --- /dev/null +++ b/source/core/StarAudio.hpp @@ -0,0 +1,95 @@ +#ifndef STAR_AUDIO_HPP +#define STAR_AUDIO_HPP + +#include "StarIODevice.hpp" + +namespace Star { + +STAR_CLASS(CompressedAudioImpl); +STAR_CLASS(UncompressedAudioImpl); +STAR_CLASS(Audio); + +STAR_EXCEPTION(AudioException, StarException); + +// Simple class for reading audio files in ogg/vorbis and wav format. +// Reads and allows for decompression of a limited subset of ogg/vorbis. Does +// not handle multiple bitstreams, sample rate or channel number changes. +// Entire stream is kept in memory, and is implicitly shared so copying Audio +// instances is not expensive. +class Audio { +public: + explicit Audio(IODevicePtr device); + Audio(Audio const& audio); + Audio(Audio&& audio); + + Audio& operator=(Audio const& audio); + Audio& operator=(Audio&& audio); + + // This function returns the number of channels that this file has. Channels + // are static throughout file. + unsigned channels() const; + + // This function returns the sample rate that this file has. Sample rates + // are static throughout file. + unsigned sampleRate() const; + + // This function returns the playtime duration of the file. + double totalTime() const; + + // This function returns total number of samples in this file. + uint64_t totalSamples() const; + + // This function returns true when the datastream or file being read from is + // a vorbis compressed file. False otherwise. + bool compressed() const; + + // If compressed, permanently uncompresses audio for faster reading. The + // uncompressed buffer is shared with all further copies of Audio, and this + // is irreversible. + void uncompress(); + + // This function seeks the data stream to the given time in seconds. + void seekTime(double time); + + // This function seeks the data stream to the given sample number + void seekSample(uint64_t sample); + + // This function converts the current offset of the file to the time value of + // that offset in seconds. + double currentTime() const; + + // This function converts the current offset of the file to the current + // sample number. + uint64_t currentSample() const; + + // Reads into 16 bit signed buffer with channels interleaved. Returns total + // number of samples read (counting each channel individually). 0 indicates + // end of stream. + size_t readPartial(int16_t* buffer, size_t bufferSize); + + // Same as readPartial, but repeats read attempting to fill buffer as much as + // possible + size_t read(int16_t* buffer, size_t bufferSize); + + // Read into a given buffer, while also converting into the given number of + // channels at the given sample rate and playback velocity. If the number of + // channels in the file is higher, only populates lower channels, if it is + // lower, the last channel is copied to the remaining channels. Attempts to + // fill the buffer as much as possible up to end of stream. May fail to fill + // an entire buffer depending on the destinationSampleRate / velocity / + // available samples. + size_t resample(unsigned destinationChannels, unsigned destinationSampleRate, + int16_t* destinationBuffer, size_t destinationBufferSize, + double velocity = 1.0); + +private: + // If audio is uncompressed, this will be null. + CompressedAudioImplPtr m_compressed; + UncompressedAudioImplPtr m_uncompressed; + + ByteArray m_workingBuffer; +}; + +} + +#endif diff --git a/source/core/StarBTree.hpp b/source/core/StarBTree.hpp new file mode 100644 index 0000000..abd8d78 --- /dev/null +++ b/source/core/StarBTree.hpp @@ -0,0 +1,937 @@ +#ifndef STAR_B_TREE_HPP +#define STAR_B_TREE_HPP + +#include "StarList.hpp" +#include "StarMaybe.hpp" + +namespace Star { + +// Mixin class for implementing a simple B+ Tree style database. LOTS of +// possibilities for improvement, especially in batch deletes / inserts. +// +// The Base class itself must have the following interface: +// +// struct Base { +// typedef KeyT Key; +// typedef DataT Data; +// typedef PointerT Pointer; +// +// // Index and Leaf types may either be a literal struct, or a pointer, or a +// // handle or whatever. They are meant to be opaque. +// typedef IndexT Index; +// typedef LeafT Leaf; +// +// Pointer rootPointer(); +// bool rootIsLeaf(); +// void setNewRoot(Pointer pointer, bool isLeaf); +// +// Index createIndex(Pointer beginPointer); +// +// // Load an existing index. +// Index loadIndex(Pointer pointer); +// +// size_t indexPointerCount(Index const& index); +// Pointer indexPointer(Index const& index, size_t i); +// void indexUpdatePointer(Index& index, size_t i, Pointer p); +// +// Key indexKeyBefore(Index const& index, size_t i); +// void indexUpdateKeyBefore(Index& index, size_t i, Key k); +// +// void indexRemoveBefore(Index& index, size_t i); +// void indexInsertAfter(Index& index, size_t i, Key k, Pointer p); +// +// size_t indexLevel(Index const& index); +// void setIndexLevel(Index& index, size_t indexLevel); +// +// // Should return true if index should try to shift elements into this index +// // from sibling index. +// bool indexNeedsShift(Index const& index); +// +// // Should return false if no shift done. If merging, always merge to the +// // left. +// bool indexShift(Index& left, Key const& mid, Index& right); +// +// // If a split has occurred, split right and return the mid-key and new +// // right node. +// Maybe<pair<Key, Index>> indexSplit(Index& index); +// +// // Index updated, needs storing. Return pointer to stored index (may +// // change). Index will not be used after store. +// Pointer storeIndex(Index index); +// +// // Index no longer part of BTree. Index will not be used after delete. +// void deleteIndex(Index index); +// +// // Should create new empty leaf. +// Leaf createLeaf(); +// +// Leaf loadLeaf(Pointer pointer); +// +// size_t leafElementCount(Leaf const& leaf); +// Key leafKey(Leaf const& leaf, size_t i); +// Data leafData(Leaf const& leaf, size_t i); +// +// void leafInsert(Leaf& leaf, size_t i, Key k, Data d); +// void leafRemove(Leaf& leaf, size_t i); +// +// // Set and get next-leaf pointers. It is not required that next-leaf +// // pointers be kept or that they be valid, so nextLeaf may return nothing. +// void setNextLeaf(Leaf& leaf, Maybe<Pointer> n); +// Maybe<Pointer> nextLeaf(Leaf const& leaf); +// +// // Should return true if leaf should try to shift elements into this leaf +// // from sibling leaf. +// bool leafNeedsShift(Leaf const& l); +// +// // Should return false if no change necessary. If merging, always merge to +// // the left. +// bool leafShift(Leaf& left, Leaf& right); +// +// // Always split right and return new right node if split occurs. +// Maybe<Leaf> leafSplit(Leaf& leaf); +// +// // Leaf has been updated, and needs to be written to storage. Return new +// // pointer (may be different). Leaf will not be used after store. +// Pointer storeLeaf(Leaf leaf); +// +// // Leaf is no longer part of this BTree. Leaf will not be used after +// // delete. +// void deleteLeaf(Leaf leaf); +// }; +template <typename Base> +class BTreeMixin : public Base { +public: + typedef typename Base::Key Key; + typedef typename Base::Data Data; + typedef typename Base::Pointer Pointer; + + typedef typename Base::Index Index; + typedef typename Base::Leaf Leaf; + + bool contains(Key const& k); + + Maybe<Data> find(Key const& k); + + // Range is inclusve on lower bound and exclusive on upper bound. + List<pair<Key, Data>> find(Key const& lower, Key const& upper); + + // Visitor is called as visitor(key, data). + template <typename Visitor> + void forEach(Key const& lower, Key const& upper, Visitor&& visitor); + + // Visitor is called as visitor(key, data). + template <typename Visitor> + void forAll(Visitor&& visitor); + + // Recover all key value pairs possible, catching exceptions during scan and + // reading as much data as possible. Visitor is called as visitor(key, data), + // ErrorHandler is called as error(char const*, std::exception const&) + template <typename Visitor, typename ErrorHandler> + void recoverAll(Visitor&& visitor, ErrorHandler&& error); + + // Visitor is called either as visitor(Index const&) or visitor(Leaf const&). + // Return false to halt traversal, true to continue. + template <typename Visitor> + void forAllNodes(Visitor&& visitor); + + // returns true if old value overwritten. + bool insert(Key k, Data data); + + // returns true if key was found. + bool remove(Key k); + + // Removes list of keys in the given range, returns count removed. + // TODO: SLOW, right now does lots of different removes separately. Need to + // implement batch inserts and deletes. + List<pair<Key, Data>> remove(Key const& lower, Key const& upper); + + uint64_t indexCount(); + uint64_t leafCount(); + uint64_t recordCount(); + + uint32_t indexLevels(); + + void createNewRoot(); + +private: + struct DataElement { + Key key; + Data data; + }; + typedef List<DataElement> DataList; + + struct DataCollector { + void operator()(Key const& k, Data const& d); + + List<pair<Key, Data>> list; + }; + + struct RecordCounter { + bool operator()(Index const& index); + bool operator()(Leaf const& leaf); + + BTreeMixin* parent; + uint64_t count; + }; + + struct IndexCounter { + bool operator()(Index const& index); + bool operator()(Leaf const&); + + BTreeMixin* parent; + uint64_t count; + }; + + struct LeafCounter { + bool operator()(Index const& index); + bool operator()(Leaf const&); + + BTreeMixin* parent; + uint64_t count; + }; + + enum ModifyAction { + InsertAction, + RemoveAction + }; + + enum ModifyState { + LeafNeedsJoin, + IndexNeedsJoin, + LeafSplit, + IndexSplit, + LeafNeedsUpdate, + IndexNeedsUpdate, + Done + }; + + struct ModifyInfo { + ModifyInfo(ModifyAction a, DataElement e); + + DataElement targetElement; + ModifyAction action; + bool found; + ModifyState state; + + Key newKey; + Pointer newPointer; + }; + + bool contains(Index const& index, Key const& k); + bool contains(Leaf const& leaf, Key const& k); + + Maybe<Data> find(Index const& index, Key const& k); + Maybe<Data> find(Leaf const& leaf, Key const& k); + + // Returns the highest key for the last leaf we have searched + template <typename Visitor> + Key forEach(Index const& index, Key const& lower, Key const& upper, Visitor&& o); + template <typename Visitor> + Key forEach(Leaf const& leaf, Key const& lower, Key const& upper, Visitor&& o); + + // Returns the highest key for the last leaf we have searched + template <typename Visitor> + Key forAll(Index const& index, Visitor&& o); + template <typename Visitor> + Key forAll(Leaf const& leaf, Visitor&& o); + + template <typename Visitor, typename ErrorHandler> + void recoverAll(Index const& index, Visitor&& o, ErrorHandler&& error); + template <typename Visitor, typename ErrorHandler> + void recoverAll(Leaf const& leaf, Visitor&& o, ErrorHandler&& error); + + // Variable size values mean that merges can happen on inserts, so can't + // split up into insert / remove methods + void modify(Leaf& leafNode, ModifyInfo& info); + void modify(Index& indexNode, ModifyInfo& info); + bool modify(DataElement e, ModifyAction action); + + // Traverses Indexes down the tree on the left side to get the least valued + // key that is pointed to by any leaf under this index. Needed when joining. + Key getLeftKey(Index const& index); + + template <typename Visitor> + void forAllNodes(Index const& index, Visitor&& visitor); + + pair<size_t, bool> leafFind(Leaf const& leaf, Key const& key); + size_t indexFind(Index const& index, Key const& key); +}; + +template <typename Base> +bool BTreeMixin<Base>::contains(Key const& k) { + if (Base::rootIsLeaf()) + return contains(Base::loadLeaf(Base::rootPointer()), k); + else + return contains(Base::loadIndex(Base::rootPointer()), k); +} + +template <typename Base> +auto BTreeMixin<Base>::find(Key const& k) -> Maybe<Data> { + if (Base::rootIsLeaf()) + return find(Base::loadLeaf(Base::rootPointer()), k); + else + return find(Base::loadIndex(Base::rootPointer()), k); +} + +template <typename Base> +auto BTreeMixin<Base>::find(Key const& lower, Key const& upper) -> List<pair<Key, Data>> { + DataCollector collector; + forEach(lower, upper, collector); + return collector.list; +} + +template <typename Base> +template <typename Visitor> +void BTreeMixin<Base>::forEach(Key const& lower, Key const& upper, Visitor&& visitor) { + if (Base::rootIsLeaf()) + forEach(Base::loadLeaf(Base::rootPointer()), lower, upper, forward<Visitor>(visitor)); + else + forEach(Base::loadIndex(Base::rootPointer()), lower, upper, forward<Visitor>(visitor)); +} + +template <typename Base> +template <typename Visitor> +void BTreeMixin<Base>::forAll(Visitor&& visitor) { + if (Base::rootIsLeaf()) + forAll(Base::loadLeaf(Base::rootPointer()), forward<Visitor>(visitor)); + else + forAll(Base::loadIndex(Base::rootPointer()), forward<Visitor>(visitor)); +} + +template <typename Base> +template <typename Visitor, typename ErrorHandler> +void BTreeMixin<Base>::recoverAll(Visitor&& visitor, ErrorHandler&& error) { + try { + if (Base::rootIsLeaf()) + recoverAll(Base::loadLeaf(Base::rootPointer()), forward<Visitor>(visitor), forward<ErrorHandler>(error)); + else + recoverAll(Base::loadIndex(Base::rootPointer()), forward<Visitor>(visitor), forward<ErrorHandler>(error)); + } catch (std::exception const& e) { + error("Error loading root index or leaf node", e); + } +} + +template <typename Base> +template <typename Visitor> +void BTreeMixin<Base>::forAllNodes(Visitor&& visitor) { + if (Base::rootIsLeaf()) + visitor(Base::loadLeaf(Base::rootPointer())); + else + forAllNodes(Base::loadIndex(Base::rootPointer()), forward<Visitor>(visitor)); +} + +template <typename Base> +bool BTreeMixin<Base>::insert(Key k, Data data) { + return modify(DataElement{move(k), move(data)}, InsertAction); +} + +template <typename Base> +bool BTreeMixin<Base>::remove(Key k) { + return modify(DataElement{move(k), Data()}, RemoveAction); +} + +template <typename Base> +auto BTreeMixin<Base>::remove(Key const& lower, Key const& upper) -> List<pair<Key, Data>> { + DataCollector collector; + forEach(lower, upper, collector); + + for (auto const& elem : collector.list) + remove(elem.first); + + return collector.list; +} + +template <typename Base> +uint64_t BTreeMixin<Base>::indexCount() { + IndexCounter counter = {this, 0}; + forAllNodes(counter); + return counter.count; +} + +template <typename Base> +uint64_t BTreeMixin<Base>::leafCount() { + LeafCounter counter = {this, 0}; + forAllNodes(counter); + return counter.count; +} + +template <typename Base> +uint64_t BTreeMixin<Base>::recordCount() { + RecordCounter counter = {this, 0}; + forAllNodes(counter); + return counter.count; +} + +template <typename Base> +uint32_t BTreeMixin<Base>::indexLevels() { + if (Base::rootIsLeaf()) + return 0; + else + return Base::indexLevel(Base::loadIndex(Base::rootPointer())) + 1; +} + +template <typename Base> +void BTreeMixin<Base>::createNewRoot() { + Base::setNewRoot(Base::storeLeaf(Base::createLeaf()), true); +} + +template <typename Base> +void BTreeMixin<Base>::DataCollector::operator()(Key const& k, Data const& d) { + list.push_back({k, d}); +} + +template <typename Base> +bool BTreeMixin<Base>::RecordCounter::operator()(Index const&) { + return true; +} + +template <typename Base> +bool BTreeMixin<Base>::RecordCounter::operator()(Leaf const& leaf) { + count += parent->leafElementCount(leaf); + return true; +} + +template <typename Base> +bool BTreeMixin<Base>::IndexCounter::operator()(Index const& index) { + ++count; + if (parent->indexLevel(index) == 0) + return false; + else + return true; +} + +template <typename Base> +bool BTreeMixin<Base>::IndexCounter::operator()(Leaf const&) { + return false; +} + +template <typename Base> +bool BTreeMixin<Base>::LeafCounter::operator()(Index const& index) { + if (parent->indexLevel(index) == 0) { + count += parent->indexPointerCount(index); + return false; + } else { + return true; + } +} + +template <typename Base> +bool BTreeMixin<Base>::LeafCounter::operator()(Leaf const&) { + return false; +} + +template <typename Base> +BTreeMixin<Base>::ModifyInfo::ModifyInfo(ModifyAction a, DataElement e) + : targetElement(move(e)), action(a) { + found = false; + state = Done; +} + +template <typename Base> +bool BTreeMixin<Base>::contains(Index const& index, Key const& k) { + size_t i = indexFind(index, k); + if (Base::indexLevel(index) == 0) + return contains(Base::loadLeaf(Base::indexPointer(index, i)), k); + else + return contains(Base::loadIndex(Base::indexPointer(index, i)), k); +} + +template <typename Base> +bool BTreeMixin<Base>::contains(Leaf const& leaf, Key const& k) { + return leafFind(leaf, k).second; +} + +template <typename Base> +auto BTreeMixin<Base>::find(Index const& index, Key const& k) -> Maybe<Data> { + size_t i = indexFind(index, k); + if (Base::indexLevel(index) == 0) + return find(Base::loadLeaf(Base::indexPointer(index, i)), k); + else + return find(Base::loadIndex(Base::indexPointer(index, i)), k); +} + +template <typename Base> +auto BTreeMixin<Base>::find(Leaf const& leaf, Key const& k) -> Maybe<Data> { + pair<size_t, bool> res = leafFind(leaf, k); + if (res.second) + return Base::leafData(leaf, res.first); + else + return {}; +} + +template <typename Base> +template <typename Visitor> +auto BTreeMixin<Base>::forEach(Index const& index, Key const& lower, Key const& upper, Visitor&& o) -> Key { + size_t i = indexFind(index, lower); + Key lastKey; + + if (Base::indexLevel(index) == 0) + lastKey = forEach(Base::loadLeaf(Base::indexPointer(index, i)), lower, upper, forward<Visitor>(o)); + else + lastKey = forEach(Base::loadIndex(Base::indexPointer(index, i)), lower, upper, forward<Visitor>(o)); + + if (!(lastKey < upper)) + return lastKey; + + while (i < Base::indexPointerCount(index) - 1) { + ++i; + + // We're visiting the right side of the key, so if lastKey >= + // indexKeyBefore(index, i), we have already visited this node via nextLeaf + // pointers, so skip it. + if (!(lastKey < Base::indexKeyBefore(index, i))) + continue; + + if (Base::indexLevel(index) == 0) + lastKey = forEach(Base::loadLeaf(Base::indexPointer(index, i)), lower, upper, forward<Visitor>(o)); + else + lastKey = forEach(Base::loadIndex(Base::indexPointer(index, i)), lower, upper, forward<Visitor>(o)); + + if (!(lastKey < upper)) + break; + } + + return lastKey; +} + +template <typename Base> +template <typename Visitor> +auto BTreeMixin<Base>::forEach(Leaf const& leaf, Key const& lower, Key const& upper, Visitor&& o) -> Key { + if (Base::leafElementCount(leaf) == 0) + return Key(); + + size_t lowerIndex = leafFind(leaf, lower).first; + + for (size_t i = lowerIndex; i != Base::leafElementCount(leaf); ++i) { + Key currentKey = Base::leafKey(leaf, i); + if (!(currentKey < lower)) { + if (currentKey < upper) + o(currentKey, Base::leafData(leaf, i)); + else + return currentKey; + } + } + + if (auto nextLeafPointer = Base::nextLeaf(leaf)) + return forEach(Base::loadLeaf(*nextLeafPointer), lower, upper, o); + else + return Base::leafKey(leaf, Base::leafElementCount(leaf) - 1); +} + +template <typename Base> +template <typename Visitor> +auto BTreeMixin<Base>::forAll(Index const& index, Visitor&& o) -> Key { + Key lastKey; + for (size_t i = 0; i < Base::indexPointerCount(index); ++i) { + // If we're to the right of a given key, but lastKey >= this key, then we + // must have already visited this node via nextLeaf pointers, so we can + // skip it. + if (i > 0 && !(lastKey < Base::indexKeyBefore(index, i))) + continue; + + if (Base::indexLevel(index) == 0) + lastKey = forAll(Base::loadLeaf(Base::indexPointer(index, i)), forward<Visitor>(o)); + else + lastKey = forAll(Base::loadIndex(Base::indexPointer(index, i)), forward<Visitor>(o)); + } + + return lastKey; +} + +template <typename Base> +template <typename Visitor> +auto BTreeMixin<Base>::forAll(Leaf const& leaf, Visitor&& o) -> Key { + if (Base::leafElementCount(leaf) == 0) + return Key(); + + for (size_t i = 0; i != Base::leafElementCount(leaf); ++i) { + Key currentKey = Base::leafKey(leaf, i); + o(Base::leafKey(leaf, i), Base::leafData(leaf, i)); + } + + if (auto nextLeafPointer = Base::nextLeaf(leaf)) + return forAll(Base::loadLeaf(*nextLeafPointer), forward<Visitor>(o)); + else + return Base::leafKey(leaf, Base::leafElementCount(leaf) - 1); +} + +template <typename Base> +template <typename Visitor, typename ErrorHandler> +void BTreeMixin<Base>::recoverAll(Index const& index, Visitor&& visitor, ErrorHandler&& error) { + try { + for (size_t i = 0; i < Base::indexPointerCount(index); ++i) { + if (Base::indexLevel(index) == 0) { + try { + recoverAll(Base::loadLeaf(Base::indexPointer(index, i)), forward<Visitor>(visitor), forward<ErrorHandler>(error)); + } catch (std::exception const& e) { + error("Error loading leaf node", e); + } + } else { + try { + recoverAll(Base::loadIndex(Base::indexPointer(index, i)), forward<Visitor>(visitor), forward<ErrorHandler>(error)); + } catch (std::exception const& e) { + error("Error loading index node", e); + } + } + } + } catch (std::exception const& e) { + error("Error reading index node", e); + } +} + +template <typename Base> +template <typename Visitor, typename ErrorHandler> +void BTreeMixin<Base>::recoverAll(Leaf const& leaf, Visitor&& visitor, ErrorHandler&& error) { + try { + for (size_t i = 0; i != Base::leafElementCount(leaf); ++i) { + Key currentKey = Base::leafKey(leaf, i); + visitor(Base::leafKey(leaf, i), Base::leafData(leaf, i)); + } + } catch (std::exception const& e) { + error("Error reading leaf node", e); + } +} + +template <typename Base> +void BTreeMixin<Base>::modify(Leaf& leafNode, ModifyInfo& info) { + info.state = Done; + + pair<size_t, bool> res = leafFind(leafNode, info.targetElement.key); + size_t i = res.first; + if (res.second) { + info.found = true; + Base::leafRemove(leafNode, i); + } + + // No change necessary. + if (info.action == RemoveAction && !info.found) + return; + + if (info.action == InsertAction) + Base::leafInsert(leafNode, i, info.targetElement.key, move(info.targetElement.data)); + + auto splitResult = Base::leafSplit(leafNode); + if (splitResult) { + Base::setNextLeaf(*splitResult, Base::nextLeaf(leafNode)); + info.newKey = Base::leafKey(*splitResult, 0); + info.newPointer = Base::storeLeaf(splitResult.take()); + + Base::setNextLeaf(leafNode, info.newPointer); + info.state = LeafSplit; + } else if (Base::leafNeedsShift(leafNode)) { + info.state = LeafNeedsJoin; + } else { + info.state = LeafNeedsUpdate; + } +} + +template <typename Base> +void BTreeMixin<Base>::modify(Index& indexNode, ModifyInfo& info) { + size_t i = indexFind(indexNode, info.targetElement.key); + Pointer nextPointer = Base::indexPointer(indexNode, i); + + Leaf lowerLeaf; + Index lowerIndex; + if (Base::indexLevel(indexNode) == 0) { + lowerLeaf = Base::loadLeaf(nextPointer); + modify(lowerLeaf, info); + } else { + lowerIndex = Base::loadIndex(nextPointer); + modify(lowerIndex, info); + } + + if (info.state == Done) + return; + + bool selfUpdated = false; + + size_t left = 0; + size_t right = 0; + if (i != 0 && i == Base::indexPointerCount(indexNode) - 1) { + left = i - 1; + right = i; + } else { + left = i; + right = i + 1; + } + + if (info.state == LeafNeedsJoin) { + if (Base::indexPointerCount(indexNode) < 2) { + // Don't have enough leaves to join, just do the pending update. + info.state = LeafNeedsUpdate; + } else { + Leaf leftLeaf; + Leaf rightLeaf; + + if (left == i) { + leftLeaf = lowerLeaf; + rightLeaf = Base::loadLeaf(Base::indexPointer(indexNode, right)); + } else { + leftLeaf = Base::loadLeaf(Base::indexPointer(indexNode, left)); + rightLeaf = lowerLeaf; + } + + if (!Base::leafShift(leftLeaf, rightLeaf)) { + // Leaves not modified, just do the pending update. + info.state = LeafNeedsUpdate; + } else if (Base::leafElementCount(rightLeaf) == 0) { + // Leaves merged. + Base::setNextLeaf(leftLeaf, Base::nextLeaf(rightLeaf)); + Base::deleteLeaf(move(rightLeaf)); + + // Replace two sibling pointer elements with one pointing to merged + // leaf. + if (left != 0) + Base::indexUpdateKeyBefore(indexNode, left, Base::leafKey(leftLeaf, 0)); + + Base::indexUpdatePointer(indexNode, left, Base::storeLeaf(move(leftLeaf))); + Base::indexRemoveBefore(indexNode, right); + + selfUpdated = true; + } else { + // Leaves shifted. + Base::indexUpdatePointer(indexNode, left, Base::storeLeaf(move(leftLeaf))); + + // Right leaf first key changes on shift, so always need to update + // left index node. + Base::indexUpdateKeyBefore(indexNode, right, Base::leafKey(rightLeaf, 0)); + + Base::indexUpdatePointer(indexNode, right, Base::storeLeaf(move(rightLeaf))); + + selfUpdated = true; + } + } + } + + if (info.state == IndexNeedsJoin) { + if (Base::indexPointerCount(indexNode) < 2) { + // Don't have enough indexes to join, just do the pending update. + info.state = IndexNeedsUpdate; + } else { + Index leftIndex; + Index rightIndex; + + if (left == i) { + leftIndex = lowerIndex; + rightIndex = Base::loadIndex(Base::indexPointer(indexNode, right)); + } else { + leftIndex = Base::loadIndex(Base::indexPointer(indexNode, left)); + rightIndex = lowerIndex; + } + + if (!Base::indexShift(leftIndex, getLeftKey(rightIndex), rightIndex)) { + // Indexes not modified, just do the pending update. + info.state = IndexNeedsUpdate; + + } else if (Base::indexPointerCount(rightIndex) == 0) { + // Indexes merged. + Base::deleteIndex(move(rightIndex)); + + // Replace two sibling pointer elements with one pointing to merged + // index. + if (left != 0) + Base::indexUpdateKeyBefore(indexNode, left, getLeftKey(leftIndex)); + + Base::indexUpdatePointer(indexNode, left, Base::storeIndex(move(leftIndex))); + Base::indexRemoveBefore(indexNode, right); + + selfUpdated = true; + } else { + // Indexes shifted. + Base::indexUpdatePointer(indexNode, left, Base::storeIndex(move(leftIndex))); + + // Right index first key changes on shift, so always need to update + // right index node. + Key keyForRight = getLeftKey(rightIndex); + Base::indexUpdatePointer(indexNode, right, Base::storeIndex(move(rightIndex))); + Base::indexUpdateKeyBefore(indexNode, right, keyForRight); + + selfUpdated = true; + } + } + } + + if (info.state == LeafSplit) { + Base::indexUpdatePointer(indexNode, i, Base::storeLeaf(move(lowerLeaf))); + Base::indexInsertAfter(indexNode, i, info.newKey, info.newPointer); + selfUpdated = true; + } + + if (info.state == IndexSplit) { + Base::indexUpdatePointer(indexNode, i, Base::storeIndex(move(lowerIndex))); + Base::indexInsertAfter(indexNode, i, info.newKey, info.newPointer); + selfUpdated = true; + } + + if (info.state == LeafNeedsUpdate) { + Pointer lowerLeafPointer = Base::storeLeaf(move(lowerLeaf)); + if (lowerLeafPointer != Base::indexPointer(indexNode, i)) { + Base::indexUpdatePointer(indexNode, i, lowerLeafPointer); + selfUpdated = true; + } + } + + if (info.state == IndexNeedsUpdate) { + Pointer lowerIndexPointer = Base::storeIndex(move(lowerIndex)); + if (lowerIndexPointer != Base::indexPointer(indexNode, i)) { + Base::indexUpdatePointer(indexNode, i, lowerIndexPointer); + selfUpdated = true; + } + } + + auto splitResult = Base::indexSplit(indexNode); + if (splitResult) { + info.newKey = splitResult->first; + info.newPointer = Base::storeIndex(splitResult.take().second); + info.state = IndexSplit; + selfUpdated = true; + } else if (Base::indexNeedsShift(indexNode)) { + info.state = IndexNeedsJoin; + } else if (selfUpdated) { + info.state = IndexNeedsUpdate; + } else { + info.state = Done; + } +} + +template <typename Base> +bool BTreeMixin<Base>::modify(DataElement e, ModifyAction action) { + ModifyInfo info(action, move(e)); + + Leaf lowerLeaf; + Index lowerIndex; + if (Base::rootIsLeaf()) { + lowerLeaf = Base::loadLeaf(Base::rootPointer()); + modify(lowerLeaf, info); + } else { + lowerIndex = Base::loadIndex(Base::rootPointer()); + modify(lowerIndex, info); + } + + if (info.state == IndexNeedsJoin) { + if (Base::indexPointerCount(lowerIndex) == 1) { + // If root index has single pointer, then make that the new root. + + // release index first (to support the common use case of delaying + // removes until setNewRoot) + Pointer pointer = Base::indexPointer(lowerIndex, 0); + size_t level = Base::indexLevel(lowerIndex); + Base::deleteIndex(move(lowerIndex)); + Base::setNewRoot(pointer, level == 0); + } else { + // Else just update. + info.state = IndexNeedsUpdate; + } + } + + if (info.state == LeafNeedsJoin) { + // Ignore NeedsJoin on LeafNode root, just update. + info.state = LeafNeedsUpdate; + } + + if (info.state == LeafSplit || info.state == IndexSplit) { + Index newRoot; + if (info.state == IndexSplit) { + auto rootIndexLevel = Base::indexLevel(lowerIndex) + 1; + newRoot = Base::createIndex(Base::storeIndex(move(lowerIndex))); + Base::setIndexLevel(newRoot, rootIndexLevel); + } else { + newRoot = Base::createIndex(Base::storeLeaf(move(lowerLeaf))); + Base::setIndexLevel(newRoot, 0); + } + Base::indexInsertAfter(newRoot, 0, info.newKey, info.newPointer); + Base::setNewRoot(Base::storeIndex(move(newRoot)), false); + } + + if (info.state == IndexNeedsUpdate) { + Pointer newRootPointer = Base::storeIndex(move(lowerIndex)); + if (newRootPointer != Base::rootPointer()) + Base::setNewRoot(newRootPointer, false); + } + + if (info.state == LeafNeedsUpdate) { + Pointer newRootPointer = Base::storeLeaf(move(lowerLeaf)); + if (newRootPointer != Base::rootPointer()) + Base::setNewRoot(newRootPointer, true); + } + + return info.found; +} + +template <typename Base> +auto BTreeMixin<Base>::getLeftKey(Index const& index) -> Key { + if (Base::indexLevel(index) == 0) { + Leaf leaf = Base::loadLeaf(Base::indexPointer(index, 0)); + return Base::leafKey(leaf, 0); + } else { + return getLeftKey(Base::loadIndex(Base::indexPointer(index, 0))); + } +} + +template <typename Base> +template <typename Visitor> +void BTreeMixin<Base>::forAllNodes(Index const& index, Visitor&& visitor) { + if (!visitor(index)) + return; + + for (size_t i = 0; i < Base::indexPointerCount(index); ++i) { + if (Base::indexLevel(index) != 0) { + forAllNodes(Base::loadIndex(Base::indexPointer(index, i)), forward<Visitor>(visitor)); + } else { + if (!visitor(Base::loadLeaf(Base::indexPointer(index, i)))) + return; + } + } +} + +template <typename Base> +pair<size_t, bool> BTreeMixin<Base>::leafFind(Leaf const& leaf, Key const& key) { + // Return lower bound binary search result. + size_t size = Base::leafElementCount(leaf); + if (size == 0) + return {0, false}; + + size_t len = size; + size_t first = 0; + size_t middle = 0; + size_t half; + while (len > 0) { + half = len / 2; + middle = first + half; + if (Base::leafKey(leaf, middle) < key) { + first = middle + 1; + len = len - half - 1; + } else { + len = half; + } + } + return make_pair(first, first < size && !(key < Base::leafKey(leaf, first))); +} + +template <typename Base> +size_t BTreeMixin<Base>::indexFind(Index const& index, Key const& key) { + // Return upper bound binary search result of range [1, size]; + size_t size = Base::indexPointerCount(index); + if (size == 0) + return 0; + + size_t len = size - 1; + size_t first = 1; + size_t middle = 1; + size_t half; + while (len > 0) { + half = len / 2; + middle = first + half; + if (key < Base::indexKeyBefore(index, middle)) { + len = half; + } else { + first = middle + 1; + len = len - half - 1; + } + } + return first - 1; +} + +} + +#endif diff --git a/source/core/StarBTreeDatabase.cpp b/source/core/StarBTreeDatabase.cpp new file mode 100644 index 0000000..175de9e --- /dev/null +++ b/source/core/StarBTreeDatabase.cpp @@ -0,0 +1,1188 @@ +#include "StarBTreeDatabase.hpp" +#include "StarSha256.hpp" +#include "StarVlqEncoding.hpp" + +namespace Star { + +BTreeDatabase::BTreeDatabase() { + m_impl.parent = this; + m_open = false; + m_deviceSize = 0; + m_blockSize = 2048; + m_headFreeIndexBlock = InvalidBlockIndex; + m_keySize = 0; + m_autoCommit = true; + m_indexCache.setMaxSize(64); + m_root = InvalidBlockIndex; + m_rootIsLeaf = false; + m_usingAltRoot = false; +} + +BTreeDatabase::BTreeDatabase(String const& contentIdentifier, size_t keySize) + : BTreeDatabase() { + setContentIdentifier(contentIdentifier); + setKeySize(keySize); +} + +BTreeDatabase::~BTreeDatabase() { + close(); +} + +uint32_t BTreeDatabase::blockSize() const { + ReadLocker readLocker(m_lock); + return m_blockSize; +} + +void BTreeDatabase::setBlockSize(uint32_t blockSize) { + WriteLocker writeLocker(m_lock); + checkIfOpen("setBlockSize", false); + m_blockSize = blockSize; +} + +uint32_t BTreeDatabase::keySize() const { + ReadLocker readLocker(m_lock); + return m_keySize; +} + +void BTreeDatabase::setKeySize(uint32_t keySize) { + WriteLocker writeLocker(m_lock); + checkIfOpen("setKeySize", false); + m_keySize = keySize; +} + +String BTreeDatabase::contentIdentifier() const { + ReadLocker readLocker(m_lock); + return m_contentIdentifier; +} + +void BTreeDatabase::setContentIdentifier(String contentIdentifier) { + WriteLocker writeLocker(m_lock); + checkIfOpen("setContentIdentifier", false); + m_contentIdentifier = move(contentIdentifier); +} + +uint32_t BTreeDatabase::indexCacheSize() const { + SpinLocker lock(m_indexCacheSpinLock); + return m_indexCache.maxSize(); +} + +void BTreeDatabase::setIndexCacheSize(uint32_t indexCacheSize) { + SpinLocker lock(m_indexCacheSpinLock); + m_indexCache.setMaxSize(indexCacheSize); +} + +bool BTreeDatabase::autoCommit() const { + ReadLocker readLocker(m_lock); + return m_autoCommit; +} + +void BTreeDatabase::setAutoCommit(bool autoCommit) { + WriteLocker writeLocker(m_lock); + m_autoCommit = autoCommit; + if (m_autoCommit) + doCommit(); +} + +IODevicePtr BTreeDatabase::ioDevice() const { + ReadLocker readLocker(m_lock); + return m_device; +} + +void BTreeDatabase::setIODevice(IODevicePtr device) { + WriteLocker writeLocker(m_lock); + checkIfOpen("setIODevice", false); + m_device = move(device); +} + +bool BTreeDatabase::isOpen() const { + ReadLocker readLocker(m_lock); + return m_open; +} + +bool BTreeDatabase::open() { + WriteLocker writeLocker(m_lock); + if (m_open) + return false; + + if (!m_device) + throw DBException("BlockStorage::open called with no IODevice set"); + + if (!m_device->isOpen()) + m_device->open(IOMode::ReadWrite); + + m_open = true; + + if (m_device->size() > 0) { + DataStreamIODevice ds(m_device); + ds.seek(0); + + auto magic = ds.readBytes(VersionMagicSize); + if (magic != ByteArray::fromCString(VersionMagic)) + throw DBException("Device is not a valid BTreeDatabase file"); + + m_blockSize = ds.read<uint32_t>(); + + auto contentIdentifier = ds.readBytes(ContentIdentifierStringSize); + contentIdentifier.appendByte('\0'); + m_contentIdentifier = String(contentIdentifier.ptr()); + m_keySize = ds.read<uint32_t>(); + + readRoot(); + + if (m_device->isWritable()) + m_device->resize(m_deviceSize); + + return false; + + } else { + m_deviceSize = HeaderSize; + m_device->resize(m_deviceSize); + m_headFreeIndexBlock = InvalidBlockIndex; + + DataStreamIODevice ds(m_device); + ds.seek(0); + + ds.writeData(VersionMagic, VersionMagicSize); + ds.write<uint32_t>(m_blockSize); + + if (m_contentIdentifier.empty()) + throw DBException("Opening new database and no content identifier set!"); + + if (m_contentIdentifier.utf8Size() > ContentIdentifierStringSize) + throw DBException("contentIdentifier in BTreeDatabase implementation is greater than maximum identifier length"); + if (m_keySize == 0) + throw DBException("key size is not set opening a new BTreeDatabase"); + + ByteArray contentIdentifier = m_contentIdentifier.utf8Bytes(); + contentIdentifier.resize(ContentIdentifierStringSize, 0); + ds.writeBytes(contentIdentifier); + ds.write(m_keySize); + + m_impl.createNewRoot(); + doCommit(); + + return true; + } +} + +bool BTreeDatabase::contains(ByteArray const& k) { + ReadLocker readLocker(m_lock); + checkKeySize(k); + return m_impl.contains(k); +} + +Maybe<ByteArray> BTreeDatabase::find(ByteArray const& k) { + ReadLocker readLocker(m_lock); + checkKeySize(k); + return m_impl.find(k); +} + +List<pair<ByteArray, ByteArray>> BTreeDatabase::find(ByteArray const& lower, ByteArray const& upper) { + ReadLocker readLocker(m_lock); + checkKeySize(lower); + checkKeySize(upper); + return m_impl.find(lower, upper); +} + +void BTreeDatabase::forEach(ByteArray const& lower, ByteArray const& upper, function<void(ByteArray, ByteArray)> v) { + ReadLocker readLocker(m_lock); + checkKeySize(lower); + checkKeySize(upper); + m_impl.forEach(lower, upper, move(v)); +} + +void BTreeDatabase::forAll(function<void(ByteArray, ByteArray)> v) { + ReadLocker readLocker(m_lock); + m_impl.forAll(move(v)); +} + +bool BTreeDatabase::insert(ByteArray const& k, ByteArray const& data) { + WriteLocker writeLocker(m_lock); + checkKeySize(k); + return m_impl.insert(move(k), move(data)); +} + +bool BTreeDatabase::remove(ByteArray const& k) { + WriteLocker writeLocker(m_lock); + checkKeySize(k); + return m_impl.remove(k); +} + +uint64_t BTreeDatabase::recordCount() { + ReadLocker readLocker(m_lock); + return m_impl.recordCount(); +} + +uint8_t BTreeDatabase::indexLevels() { + ReadLocker readLocker(m_lock); + return m_impl.indexLevels(); +} + +uint32_t BTreeDatabase::totalBlockCount() { + ReadLocker readLocker(m_lock); + checkIfOpen("totalBlockCount", true); + return (m_device->size() - HeaderSize) / m_blockSize; +} + +uint32_t BTreeDatabase::freeBlockCount() { + ReadLocker readLocker(m_lock); + checkIfOpen("freeBlockCount", true); + + // Go through every FreeIndexBlock in the chain and count all of the tracked + // free blocks. + BlockIndex count = 0; + BlockIndex indexBlockIndex = m_headFreeIndexBlock; + while (indexBlockIndex != InvalidBlockIndex) { + FreeIndexBlock indexBlock = readFreeIndexBlock(indexBlockIndex); + count += 1 + indexBlock.freeBlocks.size(); + indexBlockIndex = indexBlock.nextFreeBlock; + } + + count += m_availableBlocks.size() + m_pendingFree.size(); + + // Include untracked blocks at the end of the file in the free count. + count += (m_device->size() - m_deviceSize) / m_blockSize; + + return count; +} + +uint32_t BTreeDatabase::indexBlockCount() { + ReadLocker readLocker(m_lock); + checkIfOpen("indexBlockCount", true); + // Indexes are simply one index per block + return m_impl.indexCount(); +} + +uint32_t BTreeDatabase::leafBlockCount() { + WriteLocker writeLocker(m_lock); + checkIfOpen("leafBlockCount", true); + + struct LeafBlocksVisitor { + bool operator()(shared_ptr<IndexNode> const&) { + return true; + } + + bool operator()(shared_ptr<LeafNode> const& leaf) { + leafBlockCount += 1 + parent->leafTailBlocks(leaf->self).size(); + return true; + } + + BTreeDatabase* parent; + BlockIndex leafBlockCount = 0; + }; + + LeafBlocksVisitor visitor; + visitor.parent = this; + m_impl.forAllNodes(visitor); + + return visitor.leafBlockCount; +} + +void BTreeDatabase::commit() { + WriteLocker writeLocker(m_lock); + doCommit(); +} + +void BTreeDatabase::rollback() { + WriteLocker writeLocker(m_lock); + + m_availableBlocks.clear(); + m_indexCache.clear(); + m_uncommitted.clear(); + m_pendingFree.clear(); + + readRoot(); + + if (m_device->isWritable()) + m_device->resize(m_deviceSize); +} + +void BTreeDatabase::close(bool closeDevice) { + WriteLocker writeLocker(m_lock); + if (m_open) { + doCommit(); + + m_indexCache.clear(); + + m_open = false; + if (closeDevice && m_device && m_device->isOpen()) + m_device->close(); + } +} + +BTreeDatabase::BlockIndex const BTreeDatabase::InvalidBlockIndex; +uint32_t const BTreeDatabase::HeaderSize; +char const* const BTreeDatabase::VersionMagic = "BTreeDB5"; +uint32_t const BTreeDatabase::VersionMagicSize; +char const* const BTreeDatabase::IndexMagic = "II"; +char const* const BTreeDatabase::LeafMagic = "LL"; +char const* const BTreeDatabase::FreeIndexMagic = "FF"; +size_t const BTreeDatabase::BTreeRootSelectorBit; +size_t const BTreeDatabase::BTreeRootInfoStart; +size_t const BTreeDatabase::BTreeRootInfoSize; + +size_t BTreeDatabase::IndexNode::pointerCount() const { + // If no begin pointer is set then the index is simply uninitialized. + if (!beginPointer) + return 0; + else + return pointers.size() + 1; +} + +auto BTreeDatabase::IndexNode::pointer(size_t i) const -> BlockIndex { + if (i == 0) + return *beginPointer; + else + return pointers.at(i - 1).pointer; +} + +void BTreeDatabase::IndexNode::updatePointer(size_t i, BlockIndex p) { + if (i == 0) + *beginPointer = p; + else + pointers.at(i - 1).pointer = p; +} + +ByteArray const& BTreeDatabase::IndexNode::keyBefore(size_t i) const { + return pointers.at(i - 1).key; +} + +void BTreeDatabase::IndexNode::updateKeyBefore(size_t i, ByteArray k) { + pointers.at(i - 1).key = k; +} + +void BTreeDatabase::IndexNode::removeBefore(size_t i) { + if (i == 0) { + beginPointer = pointers.at(0).pointer; + pointers.eraseAt(0); + } else { + pointers.eraseAt(i - 1); + } +} + +void BTreeDatabase::IndexNode::insertAfter(size_t i, ByteArray k, BlockIndex p) { + pointers.insertAt(i, Element{k, p}); +} + +uint8_t BTreeDatabase::IndexNode::indexLevel() const { + return level; +} + +void BTreeDatabase::IndexNode::setIndexLevel(uint8_t indexLevel) { + level = indexLevel; +} + +void BTreeDatabase::IndexNode::shiftLeft(ByteArray const& mid, IndexNode& right, size_t count) { + count = std::min(right.pointerCount(), count); + + if (count == 0) + return; + + pointers.append(Element{mid, *right.beginPointer}); + + ElementList::iterator s = right.pointers.begin(); + std::advance(s, count - 1); + pointers.insert(pointers.end(), right.pointers.begin(), s); + + right.pointers.erase(right.pointers.begin(), s); + if (right.pointers.size() != 0) { + right.beginPointer = right.pointers.at(0).pointer; + right.pointers.eraseAt(0); + } else { + right.beginPointer.reset(); + } +} + +void BTreeDatabase::IndexNode::shiftRight(ByteArray const& mid, IndexNode& left, size_t count) { + count = std::min(left.pointerCount(), count); + + if (count == 0) + return; + --count; + + pointers.insert(pointers.begin(), Element{mid, *beginPointer}); + + ElementList::iterator s = left.pointers.begin(); + std::advance(s, left.pointers.size() - count); + pointers.insert(pointers.begin(), s, left.pointers.end()); + + left.pointers.erase(s, left.pointers.end()); + if (left.pointers.size() != 0) { + beginPointer = left.pointers.at(left.pointers.size() - 1).pointer; + left.pointers.eraseAt(left.pointers.size() - 1); + } else { + beginPointer = left.beginPointer.take(); + } +} + +ByteArray BTreeDatabase::IndexNode::split(IndexNode& right, size_t i) { + ElementList::iterator s = pointers.begin(); + std::advance(s, i - 1); + + right.beginPointer = s->pointer; + ByteArray midKey = s->key; + right.level = level; + ++s; + + right.pointers.insert(right.pointers.begin(), s, pointers.end()); + --s; + + pointers.erase(s, pointers.end()); + + return midKey; +} + +size_t BTreeDatabase::LeafNode::count() const { + return elements.size(); +} + +ByteArray const& BTreeDatabase::LeafNode::key(size_t i) const { + return elements.at(i).key; +} + +ByteArray const& BTreeDatabase::LeafNode::data(size_t i) const { + return elements.at(i).data; +} + +void BTreeDatabase::LeafNode::insert(size_t i, ByteArray k, ByteArray d) { + elements.insertAt(i, Element{move(k), move(d)}); +} + +void BTreeDatabase::LeafNode::remove(size_t i) { + elements.eraseAt(i); +} + +void BTreeDatabase::LeafNode::shiftLeft(LeafNode& right, size_t count) { + count = std::min(right.count(), count); + + if (count == 0) + return; + + ElementList::iterator s = right.elements.begin(); + std::advance(s, count); + + elements.insert(elements.end(), right.elements.begin(), s); + right.elements.erase(right.elements.begin(), s); +} + +void BTreeDatabase::LeafNode::shiftRight(LeafNode& left, size_t count) { + count = std::min(left.count(), count); + + if (count == 0) + return; + + ElementList::iterator s = left.elements.begin(); + std::advance(s, left.elements.size() - count); + + elements.insert(elements.begin(), s, left.elements.end()); + left.elements.erase(s, left.elements.end()); +} + +void BTreeDatabase::LeafNode::split(LeafNode& right, size_t i) { + ElementList::iterator s = elements.begin(); + std::advance(s, i); + + right.elements.insert(right.elements.begin(), s, elements.end()); + elements.erase(s, elements.end()); +} + +auto BTreeDatabase::BTreeImpl::rootPointer() -> Pointer { + return parent->m_root; +} + +bool BTreeDatabase::BTreeImpl::rootIsLeaf() { + return parent->m_rootIsLeaf; +} + +void BTreeDatabase::BTreeImpl::setNewRoot(Pointer pointer, bool isLeaf) { + parent->m_root = pointer; + parent->m_rootIsLeaf = isLeaf; + + if (parent->m_autoCommit) + parent->doCommit(); +} + +auto BTreeDatabase::BTreeImpl::createIndex(Pointer beginPointer) -> Index { + auto index = make_shared<IndexNode>(); + index->self = InvalidBlockIndex; + index->level = 0; + index->beginPointer = beginPointer; + return index; +} + +auto BTreeDatabase::BTreeImpl::loadIndex(Pointer pointer) -> Index { + SpinLocker lock(parent->m_indexCacheSpinLock); + if (auto index = parent->m_indexCache.ptr(pointer)) + return *index; + lock.unlock(); + + auto index = make_shared<IndexNode>(); + + DataStreamBuffer buffer(parent->readBlock(pointer)); + + if (buffer.readBytes(2) != ByteArray(IndexMagic, 2)) + throw DBException("Error, incorrect index block signature."); + + index->self = pointer; + + index->level = buffer.read<uint8_t>(); + uint32_t s = buffer.read<uint32_t>(); + index->beginPointer = buffer.read<BlockIndex>(); + index->pointers.resize(s); + for (uint32_t i = 0; i < s; ++i) { + auto& e = index->pointers[i]; + e.key =buffer.readBytes(parent->m_keySize); + e.pointer = buffer.read<BlockIndex>(); + } + + lock.lock(); + parent->m_indexCache.set(pointer, index); + return index; +} + +bool BTreeDatabase::BTreeImpl::indexNeedsShift(Index const& index) { + return index->pointerCount() < (parent->maxIndexPointers() + 1) / 2; +} + +bool BTreeDatabase::BTreeImpl::indexShift(Index const& left, Key const& mid, Index const& right) { + if (left->pointerCount() + right->pointerCount() <= parent->maxIndexPointers()) { + left->shiftLeft(mid, *right, right->pointerCount()); + return true; + } else { + if (indexNeedsShift(right)) { + right->shiftRight(mid, *left, 1); + return true; + } else if (indexNeedsShift(left)) { + left->shiftLeft(mid, *right, 1); + return true; + } else { + return false; + } + } +} + +auto BTreeDatabase::BTreeImpl::indexSplit(Index const& index) -> Maybe<pair<Key, Index>> { + if (index->pointerCount() <= parent->maxIndexPointers()) + return {}; + + auto right = make_shared<IndexNode>(); + right->self = InvalidBlockIndex; + Key k = index->split(*right, (index->pointerCount() + 1) / 2); + return make_pair(k, right); +} + +auto BTreeDatabase::BTreeImpl::storeIndex(Index index) -> Pointer { + if (index->self != InvalidBlockIndex) { + if (!parent->m_uncommitted.contains(index->self)) { + parent->freeBlock(index->self); + parent->m_indexCache.remove(index->self); + index->self = InvalidBlockIndex; + } + } + + if (index->self == InvalidBlockIndex) + index->self = parent->reserveBlock(); + + DataStreamBuffer buffer(parent->m_blockSize); + buffer.writeData(IndexMagic, 2); + + buffer.write<uint8_t>(index->level); + buffer.write<uint32_t>(index->pointers.size()); + buffer.write<BlockIndex>(*index->beginPointer); + for (auto i = index->pointers.begin(); i != index->pointers.end(); ++i) { + starAssert(i->key.size() == parent->m_keySize); + buffer.writeBytes(i->key); + buffer.write<BlockIndex>(i->pointer); + } + + parent->updateBlock(index->self, buffer.data()); + + parent->m_indexCache.set(index->self, index); + return index->self; +} + +void BTreeDatabase::BTreeImpl::deleteIndex(Index index) { + parent->m_indexCache.remove(index->self); + parent->freeBlock(index->self); +} + +auto BTreeDatabase::BTreeImpl::createLeaf() -> Leaf { + auto leaf = make_shared<LeafNode>(); + leaf->self = InvalidBlockIndex; + return leaf; +} + +auto BTreeDatabase::BTreeImpl::loadLeaf(Pointer pointer) -> Leaf { + auto leaf = make_shared<LeafNode>(); + leaf->self = pointer; + + BlockIndex currentLeafBlock = leaf->self; + DataStreamBuffer leafBuffer; + leafBuffer.reset(parent->m_blockSize); + parent->readBlock(currentLeafBlock, 0, leafBuffer.ptr(), parent->m_blockSize); + + if (leafBuffer.readBytes(2) != ByteArray(LeafMagic, 2)) + throw DBException("Error, incorrect leaf block signature."); + + DataStreamFunctions leafInput([&](char* data, size_t len) -> size_t { + size_t pos = 0; + size_t left = len; + + while (left > 0) { + if (leafBuffer.pos() + left < parent->m_blockSize - sizeof(BlockIndex)) { + leafBuffer.readData(data + pos, left); + pos += left; + left = 0; + } else { + size_t toRead = parent->m_blockSize - sizeof(BlockIndex) - leafBuffer.pos(); + leafBuffer.readData(data + pos, toRead); + pos += toRead; + left -= toRead; + } + + if (leafBuffer.pos() == (parent->m_blockSize - sizeof(BlockIndex)) && left > 0) { + currentLeafBlock = leafBuffer.read<BlockIndex>(); + if (currentLeafBlock != InvalidBlockIndex) { + leafBuffer.reset(parent->m_blockSize); + parent->readBlock(currentLeafBlock, 0, leafBuffer.ptr(), parent->m_blockSize); + + if (leafBuffer.readBytes(2) != ByteArray(LeafMagic, 2)) + throw DBException("Error, incorrect leaf block signature."); + + } else { + throw DBException("Leaf read off end of Leaf list."); + } + } + } + + return len; + }, {}); + + uint32_t count = leafInput.read<uint32_t>(); + leaf->elements.resize(count); + for (uint32_t i = 0; i < count; ++i) { + auto& element = leaf->elements[i]; + element.key = leafInput.readBytes(parent->m_keySize); + element.data = leafInput.read<ByteArray>(); + } + + return leaf; +} + +bool BTreeDatabase::BTreeImpl::leafNeedsShift(Leaf const& l) { + return parent->leafSize(l) < parent->m_blockSize / 2; +} + +bool BTreeDatabase::BTreeImpl::leafShift(Leaf& left, Leaf& right) { + if (left->count() == 0) { + left->shiftLeft(*right, right->count()); + return true; + } + + if (right->count() == 0) + return true; + + uint32_t leftSize = parent->leafSize(left); + uint32_t rightSize = parent->leafSize(right); + if (leftSize + rightSize < parent->m_blockSize) { + left->shiftLeft(*right, right->count()); + return true; + } + + // TODO: Shifting algorithm is bad, could potentially want to shift more + // than one element here. + uint32_t rightBeginSize = parent->m_keySize + parent->dataSize(right->elements[0].data); + uint32_t leftEndSize = parent->m_keySize + parent->dataSize(left->elements[left->elements.size() - 1].data); + if (leftSize < rightSize - rightBeginSize && leftSize + rightBeginSize < parent->m_blockSize) { + left->shiftLeft(*right, 1); + return true; + } else if (rightSize < leftSize - leftEndSize && rightSize + leftEndSize < parent->m_blockSize) { + right->shiftRight(*left, 1); + return true; + } + + return false; +} + +auto BTreeDatabase::BTreeImpl::leafSplit(Leaf& leaf) -> Maybe<Leaf> { + if (leaf->elements.size() < 2) + return {}; + + uint32_t size = 6; + bool boundaryFound = false; + uint32_t boundary = 0; + for (uint32_t i = 0; i < leaf->elements.size(); ++i) { + size += parent->m_keySize; + size += parent->dataSize(leaf->elements[i].data); + if (size > parent->m_blockSize - sizeof(BlockIndex) && !boundaryFound) { + boundary = i; + boundaryFound = true; + } + } + if (boundary == 0) + boundary = 1; + + if (size < parent->m_blockSize * 2 - 2 * sizeof(BlockIndex) - 4) { + return {}; + } else { + auto right = make_shared<LeafNode>(); + right->self = InvalidBlockIndex; + leaf->split(*right, boundary); + return right; + } +} + +auto BTreeDatabase::BTreeImpl::storeLeaf(Leaf leaf) -> Pointer { + if (leaf->self != InvalidBlockIndex) { + List<BlockIndex> tailBlocks = parent->leafTailBlocks(leaf->self); + for (uint32_t i = 0; i < tailBlocks.size(); ++i) + parent->freeBlock(tailBlocks[i]); + + if (!parent->m_uncommitted.contains(leaf->self)) { + parent->freeBlock(leaf->self); + leaf->self = InvalidBlockIndex; + } + } + + if (leaf->self == InvalidBlockIndex) + leaf->self = parent->reserveBlock(); + + BlockIndex currentLeafBlock = leaf->self; + DataStreamBuffer leafBuffer; + leafBuffer.reset(parent->m_blockSize); + leafBuffer.writeData(LeafMagic, 2); + + DataStreamFunctions leafOutput({}, [&](char const* data, size_t len) -> size_t { + size_t pos = 0; + size_t left = len; + + while (true) { + size_t toWrite = left; + if (toWrite > parent->m_blockSize - leafBuffer.pos() - sizeof(BlockIndex)) + toWrite = parent->m_blockSize - leafBuffer.pos() - sizeof(BlockIndex); + + if (toWrite != 0) { + leafBuffer.writeData(data + pos, toWrite); + left -= toWrite; + pos += toWrite; + } + + if (left == 0) + break; + + if (leafBuffer.pos() == (parent->m_blockSize - sizeof(BlockIndex))) { + BlockIndex nextBlock = parent->reserveBlock(); + leafBuffer.write<BlockIndex>(nextBlock); + parent->updateBlock(currentLeafBlock, leafBuffer.data()); + currentLeafBlock = nextBlock; + leafBuffer.reset(parent->m_blockSize); + leafBuffer.writeData(LeafMagic, 2); + } + } + + return len; + }); + + leafOutput.write<uint32_t>(leaf->elements.size()); + + for (LeafNode::ElementList::iterator i = leaf->elements.begin(); i != leaf->elements.end(); ++i) { + starAssert(i->key.size() == parent->m_keySize); + leafOutput.writeBytes(i->key); + leafOutput.write(i->data); + } + + leafBuffer.seek(parent->m_blockSize - sizeof(BlockIndex)); + leafBuffer.write<BlockIndex>(InvalidBlockIndex); + parent->updateBlock(currentLeafBlock, leafBuffer.data()); + + return leaf->self; +} + +void BTreeDatabase::BTreeImpl::deleteLeaf(Leaf leaf) { + List<BlockIndex> tailBlocks = parent->leafTailBlocks(leaf->self); + for (uint32_t i = 0; i < tailBlocks.size(); ++i) + parent->freeBlock(tailBlocks[i]); + + parent->freeBlock(leaf->self); +} + +size_t BTreeDatabase::BTreeImpl::indexPointerCount(Index const& index) { + return index->pointerCount(); +} + +auto BTreeDatabase::BTreeImpl::indexPointer(Index const& index, size_t i) -> Pointer { + return index->pointer(i); +} + +void BTreeDatabase::BTreeImpl::indexUpdatePointer(Index& index, size_t i, Pointer p) { + index->updatePointer(i, p); +} + +auto BTreeDatabase::BTreeImpl::indexKeyBefore(Index const& index, size_t i) -> Key { + return index->keyBefore(i); +} + +void BTreeDatabase::BTreeImpl::indexUpdateKeyBefore(Index& index, size_t i, Key k) { + index->updateKeyBefore(i, k); +} + +void BTreeDatabase::BTreeImpl::indexRemoveBefore(Index& index, size_t i) { + index->removeBefore(i); +} + +void BTreeDatabase::BTreeImpl::indexInsertAfter(Index& index, size_t i, Key k, Pointer p) { + index->insertAfter(i, k, p); +} + +size_t BTreeDatabase::BTreeImpl::indexLevel(Index const& index) { + return index->indexLevel(); +} + +void BTreeDatabase::BTreeImpl::setIndexLevel(Index& index, size_t indexLevel) { + index->setIndexLevel(indexLevel); +} + +size_t BTreeDatabase::BTreeImpl::leafElementCount(Leaf const& leaf) { + return leaf->count(); +} + +auto BTreeDatabase::BTreeImpl::leafKey(Leaf const& leaf, size_t i) -> Key { + return leaf->key(i); +} + +auto BTreeDatabase::BTreeImpl::leafData(Leaf const& leaf, size_t i) -> Data { + return leaf->data(i); +} + +void BTreeDatabase::BTreeImpl::leafInsert(Leaf& leaf, size_t i, Key k, Data d) { + leaf->insert(i, move(k), move(d)); +} + +void BTreeDatabase::BTreeImpl::leafRemove(Leaf& leaf, size_t i) { + leaf->remove(i); +} + +auto BTreeDatabase::BTreeImpl::nextLeaf(Leaf const&) -> Maybe<Pointer> { + return {}; +} + +void BTreeDatabase::BTreeImpl::setNextLeaf(Leaf&, Maybe<Pointer>) {} + +void BTreeDatabase::readBlock(BlockIndex blockIndex, size_t blockOffset, char* block, size_t size) const { + checkBlockIndex(blockIndex); + rawReadBlock(blockIndex, blockOffset, block, size); +} + +ByteArray BTreeDatabase::readBlock(BlockIndex blockIndex) const { + ByteArray block(m_blockSize, 0); + readBlock(blockIndex, 0, block.ptr(), m_blockSize); + return block; +} + +void BTreeDatabase::updateBlock(BlockIndex blockIndex, ByteArray const& block) { + checkBlockIndex(blockIndex); + rawWriteBlock(blockIndex, 0, block.ptr(), block.size()); +} + +void BTreeDatabase::rawReadBlock(BlockIndex blockIndex, size_t blockOffset, char* block, size_t size) const { + if (blockOffset > m_blockSize || size > m_blockSize - blockOffset) + throw DBException::format("Read past end of block, offset: %s size %s", blockOffset, size); + + if (size <= 0) + return; + + m_device->readFullAbsolute(HeaderSize + blockIndex * (StreamOffset)m_blockSize + blockOffset, block, size); +} + +void BTreeDatabase::rawWriteBlock(BlockIndex blockIndex, size_t blockOffset, char const* block, size_t size) const { + if (blockOffset > m_blockSize || size > m_blockSize - blockOffset) + throw DBException::format("Write past end of block, offset: %s size %s", blockOffset, size); + + if (size <= 0) + return; + + m_device->writeFullAbsolute(HeaderSize + blockIndex * (StreamOffset)m_blockSize + blockOffset, block, size); +} + +auto BTreeDatabase::readFreeIndexBlock(BlockIndex blockIndex) -> FreeIndexBlock { + checkBlockIndex(blockIndex); + + ByteArray magic(2, 0); + rawReadBlock(blockIndex, 0, magic.ptr(), 2); + if (magic != ByteArray(FreeIndexMagic, 2)) + throw DBException::format("Internal exception! block %s missing free index block marker!", blockIndex); + + FreeIndexBlock freeIndexBlock; + DataStreamBuffer buffer(max(sizeof(BlockIndex), (size_t)4)); + + rawReadBlock(blockIndex, 2, buffer.ptr(), sizeof(BlockIndex)); + buffer.seek(0); + freeIndexBlock.nextFreeBlock = buffer.read<BlockIndex>(); + + rawReadBlock(blockIndex, 2 + sizeof(BlockIndex), buffer.ptr(), 4); + buffer.seek(0); + size_t numFree = buffer.read<uint32_t>(); + + for (size_t i = 0; i < numFree; ++i) { + rawReadBlock(blockIndex, 6 + sizeof(BlockIndex) + sizeof(BlockIndex) * i, buffer.ptr(), sizeof(BlockIndex)); + buffer.seek(0); + freeIndexBlock.freeBlocks.append(buffer.read<BlockIndex>()); + } + + return freeIndexBlock; +} + +void BTreeDatabase::writeFreeIndexBlock(BlockIndex blockIndex, FreeIndexBlock indexBlock) { + checkBlockIndex(blockIndex); + + rawWriteBlock(blockIndex, 0, FreeIndexMagic, 2); + DataStreamBuffer buffer(max(sizeof(BlockIndex), (size_t)4)); + + buffer.seek(0); + buffer.write<BlockIndex>(indexBlock.nextFreeBlock); + rawWriteBlock(blockIndex, 2, buffer.ptr(), sizeof(BlockIndex)); + + buffer.seek(0); + buffer.write<uint32_t>(indexBlock.freeBlocks.size()); + rawWriteBlock(blockIndex, 2 + sizeof(BlockIndex), buffer.ptr(), 4); + + for (size_t i = 0; i < indexBlock.freeBlocks.size(); ++i) { + buffer.seek(0); + buffer.write<BlockIndex>(indexBlock.freeBlocks[i]); + rawWriteBlock(blockIndex, 6 + sizeof(BlockIndex) + sizeof(BlockIndex) * i, buffer.ptr(), sizeof(BlockIndex)); + } +} + +uint32_t BTreeDatabase::leafSize(shared_ptr<LeafNode> const& leaf) const { + size_t s = 6; + for (LeafNode::ElementList::iterator i = leaf->elements.begin(); i != leaf->elements.end(); ++i) { + s += m_keySize; + s += dataSize(i->data); + } + return s; +} + +uint32_t BTreeDatabase::maxIndexPointers() const { + // 2 for magic, 1 byte for level, sizeof(BlockIndex) for beginPointer, 4 + // for size. + return (m_blockSize - 2 - 1 - sizeof(BlockIndex) - 4) / (m_keySize + sizeof(BlockIndex)) + 1; +} + +uint32_t BTreeDatabase::dataSize(ByteArray const& d) const { + return vlqUSize(d.size()) + d.size(); +} + +auto BTreeDatabase::leafTailBlocks(BlockIndex leafPointer) -> List<BlockIndex> { + List<BlockIndex> tailBlocks; + DataStreamBuffer pointerBuffer(sizeof(BlockIndex)); + while (leafPointer != InvalidBlockIndex) { + readBlock(leafPointer, m_blockSize - sizeof(BlockIndex), pointerBuffer.ptr(), sizeof(BlockIndex)); + pointerBuffer.seek(0); + leafPointer = pointerBuffer.read<BlockIndex>(); + if (leafPointer != InvalidBlockIndex) + tailBlocks.append(leafPointer); + } + return tailBlocks; +} + +void BTreeDatabase::freeBlock(BlockIndex b) { + if (m_uncommitted.contains(b)) { + m_uncommitted.remove(b); + m_availableBlocks.add(b); + } else { + m_pendingFree.append(b); + } +} + +auto BTreeDatabase::reserveBlock() -> BlockIndex { + if (m_availableBlocks.empty()) { + if (m_headFreeIndexBlock != InvalidBlockIndex) { + // If available, make available all the blocks in the first free index + // block. + FreeIndexBlock indexBlock = readFreeIndexBlock(m_headFreeIndexBlock); + for (auto const& b : indexBlock.freeBlocks) + m_availableBlocks.add(b); + // We cannot make available the block itself, because we must maintain + // atomic consistency. We will need to free this block later and commit + // the new free index block chain. + m_pendingFree.append(m_headFreeIndexBlock); + m_headFreeIndexBlock = indexBlock.nextFreeBlock; + } + + if (m_availableBlocks.empty()) { + // If we still don't have any available blocks, just add a block to the + // end of the file. + m_availableBlocks.add(makeEndBlock()); + } + } + + BlockIndex block = m_availableBlocks.takeFirst(); + m_uncommitted.add(block); + return block; +} + +auto BTreeDatabase::makeEndBlock() -> BlockIndex { + BlockIndex blockCount = (m_deviceSize - HeaderSize) / m_blockSize; + m_deviceSize += m_blockSize; + m_device->resize(m_deviceSize); + return blockCount; +} + +void BTreeDatabase::writeRoot() { + DataStreamIODevice ds(m_device); + // First write the root info to whichever section we are not currently using + ds.seek(BTreeRootInfoStart + (m_usingAltRoot ? 0 : BTreeRootInfoSize)); + ds.write<BlockIndex>(m_headFreeIndexBlock); + ds.write<StreamOffset>(m_deviceSize); + ds.write<BlockIndex>(m_root); + ds.write<bool>(m_rootIsLeaf); + + // Then flush all the pending changes. + m_device->sync(); + + // Then switch headers by writing the single bit that switches them + m_usingAltRoot = !m_usingAltRoot; + ds.seek(BTreeRootSelectorBit); + ds.write(m_usingAltRoot); + + // Then flush this single bit write to make sure it happens before anything + // else. + m_device->sync(); +} + +void BTreeDatabase::readRoot() { + DataStreamIODevice ds(m_device); + ds.seek(BTreeRootSelectorBit); + ds.read(m_usingAltRoot); + + ds.seek(BTreeRootInfoStart + (m_usingAltRoot ? BTreeRootInfoSize : 0)); + m_headFreeIndexBlock = ds.read<BlockIndex>(); + m_deviceSize = ds.read<StreamOffset>(); + m_root = ds.read<BlockIndex>(); + m_rootIsLeaf = ds.read<bool>(); +} + +void BTreeDatabase::doCommit() { + if (m_availableBlocks.empty() && m_pendingFree.empty() && m_uncommitted.empty()) + return; + + if (!m_availableBlocks.empty() || !m_pendingFree.empty()) { + // First, read the existing head FreeIndexBlock, if it exists + FreeIndexBlock indexBlock = FreeIndexBlock{InvalidBlockIndex, {}}; + if (m_headFreeIndexBlock != InvalidBlockIndex) { + indexBlock = readFreeIndexBlock(m_headFreeIndexBlock); + if (indexBlock.freeBlocks.size() >= maxFreeIndexLength()) { + // If the existing head free index block is full, then we should start a + // new one and leave it alone + indexBlock.nextFreeBlock = m_headFreeIndexBlock; + indexBlock.freeBlocks.clear(); + } else { + // If we are copying an existing free index block, the old free index + // block will be a newly freed block + indexBlock.freeBlocks.append(m_headFreeIndexBlock); + } + } + + // Then, we need to write all the available blocks, which are safe to write + // to, and the pending free blocks, which are NOT safe to write to, to the + // FreeIndexBlock chain. + while (true) { + if (indexBlock.freeBlocks.size() < maxFreeIndexLength() && (!m_availableBlocks.empty() || !m_pendingFree.empty())) { + // If we have room on our current FreeIndexblock, just add a block to + // it. Prioritize the pending free blocks, because we cannot use those + // to write to. + BlockIndex toAdd; + if (m_pendingFree.empty()) + toAdd = m_availableBlocks.takeFirst(); + else + toAdd = m_pendingFree.takeFirst(); + + indexBlock.freeBlocks.append(toAdd); + } else { + // If our index block is full OR we are out of blocks to free, then + // need to write a new head free index block. + if (m_availableBlocks.empty()) + m_headFreeIndexBlock = makeEndBlock(); + else + m_headFreeIndexBlock = m_availableBlocks.takeFirst(); + writeFreeIndexBlock(m_headFreeIndexBlock, indexBlock); + + // If we're out of blocks to free, then we're done + if (m_availableBlocks.empty() && m_pendingFree.empty()) + break; + + indexBlock.nextFreeBlock = m_headFreeIndexBlock; + indexBlock.freeBlocks.clear(); + } + } + } + + writeRoot(); + + m_uncommitted.clear(); +} + +void BTreeDatabase::checkIfOpen(char const* methodName, bool shouldBeOpen) const { + if (shouldBeOpen && !m_open) + throw DBException::format("BTreeDatabase method '%s' called when not open, must be open.", methodName); + else if (!shouldBeOpen && m_open) + throw DBException::format("BTreeDatabase method '%s' called when open, cannot call when open.", methodName); +} + +void BTreeDatabase::checkBlockIndex(size_t blockIndex) const { + BlockIndex blockCount = (m_deviceSize - HeaderSize) / m_blockSize; + if (blockIndex >= blockCount) + throw DBException::format("blockIndex: %s out of block range", blockIndex); +} + +void BTreeDatabase::checkKeySize(ByteArray const& k) const { + if (k.size() != m_keySize) + throw DBException::format("Wrong key size %s", k.size()); +} + +uint32_t BTreeDatabase::maxFreeIndexLength() const { + return (m_blockSize - 2 - sizeof(BlockIndex) - 4) / sizeof(BlockIndex); +} + +BTreeSha256Database::BTreeSha256Database() { + setKeySize(32); +} + +BTreeSha256Database::BTreeSha256Database(String const& contentIdentifier) { + setKeySize(32); + setContentIdentifier(contentIdentifier); +} + +bool BTreeSha256Database::contains(ByteArray const& key) { + return BTreeDatabase::contains(sha256(key)); +} + +Maybe<ByteArray> BTreeSha256Database::find(ByteArray const& key) { + return BTreeDatabase::find(sha256(key)); +} + +bool BTreeSha256Database::insert(ByteArray const& key, ByteArray const& value) { + return BTreeDatabase::insert(sha256(key), value); +} + +bool BTreeSha256Database::remove(ByteArray const& key) { + return BTreeDatabase::remove(sha256(key)); +} + +bool BTreeSha256Database::contains(String const& key) { + return BTreeDatabase::contains(sha256(key)); +} + +Maybe<ByteArray> BTreeSha256Database::find(String const& key) { + return BTreeDatabase::find(sha256(key)); +} + +bool BTreeSha256Database::insert(String const& key, ByteArray const& value) { + return BTreeDatabase::insert(sha256(key), value); +} + +bool BTreeSha256Database::remove(String const& key) { + return BTreeDatabase::remove(sha256(key)); +} + +} diff --git a/source/core/StarBTreeDatabase.hpp b/source/core/StarBTreeDatabase.hpp new file mode 100644 index 0000000..aff45d3 --- /dev/null +++ b/source/core/StarBTreeDatabase.hpp @@ -0,0 +1,344 @@ +#ifndef STAR_BTREE_DATABASE_HPP +#define STAR_BTREE_DATABASE_HPP + +#include "StarSet.hpp" +#include "StarBTree.hpp" +#include "StarLruCache.hpp" +#include "StarDataStreamDevices.hpp" +#include "StarThread.hpp" + +namespace Star { + +STAR_EXCEPTION(DBException, IOException); + +class BTreeDatabase { +public: + uint32_t const ContentIdentifierStringSize = 16; + + BTreeDatabase(); + BTreeDatabase(String const& contentIdentifier, size_t keySize); + ~BTreeDatabase(); + + // The underlying device will be allocated in "blocks" of this size. + // The larger the block size, the larger that index and leaf nodes can be + // before they need to be split, but it also means that more space is wasted + // for index or leaf nodes that are not completely full. Cannot be changed + // once the database is opened. Defaults to 2048. + uint32_t blockSize() const; + void setBlockSize(uint32_t blockSize); + + // Constant size of the database keys. Should be much smaller than the block + // size, cannot be changed once a database is opened. Defaults zero, which + // is invalid, so must be set if opening a new database. + uint32_t keySize() const; + void setKeySize(uint32_t keySize); + + // Must be no greater than ContentIdentifierStringSize large. May not be + // called when the database is opened. + String contentIdentifier() const; + void setContentIdentifier(String contentIdentifier); + + // Cache size for index nodes, defaults to 64 + uint32_t indexCacheSize() const; + void setIndexCacheSize(uint32_t indexCacheSize); + + // If true, very write operation will immediately result in a commit. + // Defaults to true. + bool autoCommit() const; + void setAutoCommit(bool autoCommit); + + IODevicePtr ioDevice() const; + void setIODevice(IODevicePtr device); + + // If an existing database is opened, this will update the key size, block + // size, and content identifier with those from the opened database. + // Otherwise, it will use the currently set values. Returns true if a new + // database was created, false if an existing database was found and opened. + bool open(); + + bool isOpen() const; + + bool contains(ByteArray const& k); + + Maybe<ByteArray> find(ByteArray const& k); + List<pair<ByteArray, ByteArray>> find(ByteArray const& lower, ByteArray const& upper); + + void forEach(ByteArray const& lower, ByteArray const& upper, function<void(ByteArray, ByteArray)> v); + void forAll(function<void(ByteArray, ByteArray)> v); + + // Returns true if a value was overwritten + bool insert(ByteArray const& k, ByteArray const& data); + + // Returns true if the element was found and removed + bool remove(ByteArray const& k); + + // Remove all elements in the given range, returns keys removed. + List<ByteArray> remove(ByteArray const& lower, ByteArray const& upper); + + uint64_t recordCount(); + + // The depth of the index nodes in this database + uint8_t indexLevels(); + + uint32_t totalBlockCount(); + uint32_t freeBlockCount(); + uint32_t indexBlockCount(); + uint32_t leafBlockCount(); + + void commit(); + void rollback(); + + void close(bool closeDevice = false); + +private: + typedef uint32_t BlockIndex; + static BlockIndex const InvalidBlockIndex = (BlockIndex)(-1); + static uint32_t const HeaderSize = 512; + + // 8 byte magic file identifier + static char const* const VersionMagic; + static uint32_t const VersionMagicSize = 8; + // 2 byte leaf and index start markers. + static char const* const FreeIndexMagic; + static char const* const IndexMagic; + static char const* const LeafMagic; + // static uint32_t const BlockMagicSize = 2; + static size_t const BTreeRootSelectorBit = 32; + static size_t const BTreeRootInfoStart = 33; + static size_t const BTreeRootInfoSize = 17; + + struct FreeIndexBlock { + BlockIndex nextFreeBlock; + List<BlockIndex> freeBlocks; + }; + + struct IndexNode { + size_t pointerCount() const; + BlockIndex pointer(size_t i) const; + void updatePointer(size_t i, BlockIndex p); + + ByteArray const& keyBefore(size_t i) const; + void updateKeyBefore(size_t i, ByteArray k); + + void removeBefore(size_t i); + void insertAfter(size_t i, ByteArray k, BlockIndex p); + + uint8_t indexLevel() const; + void setIndexLevel(uint8_t indexLevel); + + // count is number of elements to shift left *including* right's beginPointer + void shiftLeft(ByteArray const& mid, IndexNode& right, size_t count); + + // count is number of elements to shift right + void shiftRight(ByteArray const& mid, IndexNode& left, size_t count); + + // i should be index of pointer that will be the new beginPointer of right + // node (cannot be 0). + ByteArray split(IndexNode& right, size_t i); + + struct Element { + ByteArray key; + BlockIndex pointer; + }; + typedef List<Element> ElementList; + + BlockIndex self; + uint8_t level; + Maybe<BlockIndex> beginPointer; + ElementList pointers; + }; + + struct LeafNode { + size_t count() const; + ByteArray const& key(size_t i) const; + ByteArray const& data(size_t i) const; + + void insert(size_t i, ByteArray k, ByteArray d); + void remove(size_t i); + + // count is number of elements to shift left + void shiftLeft(LeafNode& right, size_t count); + + // count is number of elements to shift right + void shiftRight(LeafNode& left, size_t count); + + // i should be index of element that will be the new start of right node. + // Returns right index node. + void split(LeafNode& right, size_t i); + + struct Element { + ByteArray key; + ByteArray data; + }; + typedef List<Element> ElementList; + + BlockIndex self; + ElementList elements; + }; + + struct BTreeImpl { + typedef ByteArray Key; + typedef ByteArray Data; + typedef BlockIndex Pointer; + + typedef shared_ptr<IndexNode> Index; + typedef shared_ptr<LeafNode> Leaf; + + Pointer rootPointer(); + bool rootIsLeaf(); + void setNewRoot(Pointer pointer, bool isLeaf); + + Index createIndex(Pointer beginPointer); + Index loadIndex(Pointer pointer); + bool indexNeedsShift(Index const& index); + bool indexShift(Index const& left, Key const& mid, Index const& right); + Maybe<pair<Key, Index>> indexSplit(Index const& index); + Pointer storeIndex(Index index); + void deleteIndex(Index index); + + Leaf createLeaf(); + Leaf loadLeaf(Pointer pointer); + bool leafNeedsShift(Leaf const& l); + bool leafShift(Leaf& left, Leaf& right); + Maybe<Leaf> leafSplit(Leaf& leaf); + Pointer storeLeaf(Leaf leaf); + void deleteLeaf(Leaf leaf); + + size_t indexPointerCount(Index const& index); + Pointer indexPointer(Index const& index, size_t i); + void indexUpdatePointer(Index& index, size_t i, Pointer p); + Key indexKeyBefore(Index const& index, size_t i); + void indexUpdateKeyBefore(Index& index, size_t i, Key k); + void indexRemoveBefore(Index& index, size_t i); + void indexInsertAfter(Index& index, size_t i, Key k, Pointer p); + size_t indexLevel(Index const& index); + void setIndexLevel(Index& index, size_t indexLevel); + + size_t leafElementCount(Leaf const& leaf); + Key leafKey(Leaf const& leaf, size_t i); + Data leafData(Leaf const& leaf, size_t i); + void leafInsert(Leaf& leaf, size_t i, Key k, Data d); + void leafRemove(Leaf& leaf, size_t i); + Maybe<Pointer> nextLeaf(Leaf const& leaf); + void setNextLeaf(Leaf& leaf, Maybe<Pointer> n); + + BTreeDatabase* parent; + }; + + void readBlock(BlockIndex blockIndex, size_t blockOffset, char* block, size_t size) const; + ByteArray readBlock(BlockIndex blockIndex) const; + void updateBlock(BlockIndex blockIndex, ByteArray const& block); + + void rawReadBlock(BlockIndex blockIndex, size_t blockOffset, char* block, size_t size) const; + void rawWriteBlock(BlockIndex blockIndex, size_t blockOffset, char const* block, size_t size) const; + + void updateHeadFreeIndexBlock(BlockIndex newHead); + + FreeIndexBlock readFreeIndexBlock(BlockIndex blockIndex); + void writeFreeIndexBlock(BlockIndex blockIndex, FreeIndexBlock indexBlock); + + uint32_t leafSize(shared_ptr<LeafNode> const& leaf) const; + uint32_t maxIndexPointers() const; + + uint32_t dataSize(ByteArray const& d) const; + List<BlockIndex> leafTailBlocks(BlockIndex leafPointer); + + void freeBlock(BlockIndex b); + BlockIndex reserveBlock(); + BlockIndex makeEndBlock(); + + void dirty(); + void writeRoot(); + void readRoot(); + void doCommit(); + + void checkIfOpen(char const* methodName, bool shouldBeOpen) const; + void checkBlockIndex(size_t blockIndex) const; + void checkKeySize(ByteArray const& k) const; + uint32_t maxFreeIndexLength() const; + + mutable ReadersWriterMutex m_lock; + + BTreeMixin<BTreeImpl> m_impl; + + IODevicePtr m_device; + bool m_open; + + uint32_t m_blockSize; + String m_contentIdentifier; + uint32_t m_keySize; + + bool m_autoCommit; + + // Reading values can mutate the index cache, so the index cache is kept + // using a different lock. It is only necessary to acquire this lock when + // NOT holding the main writer lock, because if the main writer lock is held + // then no other method would be loading an index anyway. + mutable SpinLock m_indexCacheSpinLock; + LruCache<BlockIndex, shared_ptr<IndexNode>> m_indexCache; + + BlockIndex m_headFreeIndexBlock; + StreamOffset m_deviceSize; + BlockIndex m_root; + bool m_rootIsLeaf; + bool m_usingAltRoot; + bool m_dirty; + + // Blocks that can be freely allocated and written to without violating + // atomic consistency + Set<BlockIndex> m_availableBlocks; + + // Blocks to be freed on next commit. + Deque<BlockIndex> m_pendingFree; + + // Blocks that have been written in uncommitted portions of the tree. + Set<BlockIndex> m_uncommitted; +}; + +// Version of BTreeDatabase that hashes keys with SHA-256 to produce a unique +// constant size key. +class BTreeSha256Database : private BTreeDatabase { +public: + BTreeSha256Database(); + BTreeSha256Database(String const& contentIdentifier); + + // Keys can be arbitrary size, actual key is the SHA-256 checksum of the key. + bool contains(ByteArray const& key); + Maybe<ByteArray> find(ByteArray const& key); + bool insert(ByteArray const& key, ByteArray const& value); + bool remove(ByteArray const& key); + + // Convenience string versions of access methods. Equivalent to the utf8 + // bytes of the string minus the null terminator. + bool contains(String const& key); + Maybe<ByteArray> find(String const& key); + bool insert(String const& key, ByteArray const& value); + bool remove(String const& key); + + using BTreeDatabase::ContentIdentifierStringSize; + using BTreeDatabase::blockSize; + using BTreeDatabase::setBlockSize; + using BTreeDatabase::contentIdentifier; + using BTreeDatabase::setContentIdentifier; + using BTreeDatabase::indexCacheSize; + using BTreeDatabase::setIndexCacheSize; + using BTreeDatabase::autoCommit; + using BTreeDatabase::setAutoCommit; + using BTreeDatabase::ioDevice; + using BTreeDatabase::setIODevice; + using BTreeDatabase::open; + using BTreeDatabase::isOpen; + using BTreeDatabase::recordCount; + using BTreeDatabase::indexLevels; + using BTreeDatabase::totalBlockCount; + using BTreeDatabase::freeBlockCount; + using BTreeDatabase::indexBlockCount; + using BTreeDatabase::leafBlockCount; + using BTreeDatabase::commit; + using BTreeDatabase::rollback; + using BTreeDatabase::close; +}; + +} + +#endif diff --git a/source/core/StarBiMap.hpp b/source/core/StarBiMap.hpp new file mode 100644 index 0000000..3953fde --- /dev/null +++ b/source/core/StarBiMap.hpp @@ -0,0 +1,419 @@ +#ifndef STAR_BI_MAP_HPP +#define STAR_BI_MAP_HPP + +#include "StarString.hpp" + +namespace Star { + +// Bi-directional map of unique sets of elements with quick map access from +// either the left or right element to the other side. Every left side value +// must be unique from every other left side value and the same for the right +// side. +template <typename LeftT, + typename RightT, + typename LeftMapT = Map<LeftT, RightT const*>, + typename RightMapT = Map<RightT, LeftT const*>> +class BiMap { +public: + typedef LeftT Left; + typedef RightT Right; + typedef LeftMapT LeftMap; + typedef RightMapT RightMap; + + typedef pair<Left, Right> value_type; + + struct BiMapIterator { + BiMapIterator& operator++(); + BiMapIterator operator++(int); + + bool operator==(BiMapIterator const& rhs) const; + bool operator!=(BiMapIterator const& rhs) const; + + pair<Left const&, Right const&> operator*() const; + + typename LeftMap::const_iterator iterator; + }; + + typedef BiMapIterator iterator; + typedef iterator const_iterator; + + template <typename Collection> + static BiMap from(Collection const& c); + + BiMap(); + BiMap(BiMap const& map); + + template <typename InputIterator> + BiMap(InputIterator beg, InputIterator end); + + BiMap(std::initializer_list<value_type> list); + + List<Left> leftValues() const; + List<Right> rightValues() const; + List<value_type> pairs() const; + + bool hasLeftValue(Left const& left) const; + bool hasRightValue(Right const& right) const; + + Right const& getRight(Left const& left) const; + Left const& getLeft(Right const& right) const; + + Right valueRight(Left const& left, Right const& def = Right()) const; + Left valueLeft(Right const& right, Left const& def = Left()) const; + + Maybe<Right> maybeRight(Left const& left) const; + + Maybe<Left> maybeLeft(Right const& right) const; + + Right takeRight(Left const& left); + Left takeLeft(Right const& right); + + Maybe<Right> maybeTakeRight(Left const& left); + Maybe<Left> maybeTakeLeft(Right const& right); + + Right const* rightPtr(Left const& left) const; + Left const* leftPtr(Right const& right) const; + + BiMap& operator=(BiMap const& map); + + pair<iterator, bool> insert(value_type const& val); + + // Returns true if value was inserted, false if either the left or right side + // already existed. + bool insert(Left const& left, Right const& right); + + // Throws an exception if the pair cannot be inserted + void add(Left const& left, Right const& right); + void add(value_type const& value); + + // Overwrites the left / right mapping regardless of whether each side + // already exists. + void overwrite(Left const& left, Right const& right); + void overwrite(value_type const& value); + + // Removes the pair with the given left side, returns true if this pair was + // found, false otherwise. + bool removeLeft(Left const& left); + + // Removes the pair with the given right side, returns true if this pair was + // found, false otherwise. + bool removeRight(Right const& right); + + const_iterator begin() const; + const_iterator end() const; + + size_t size() const; + + void clear(); + + bool empty() const; + + bool operator==(BiMap const& m) const; + +private: + LeftMap m_leftMap; + RightMap m_rightMap; +}; + +template <typename Left, typename Right, typename LeftHash = Star::hash<Left>, typename RightHash = Star::hash<Right>> +using BiHashMap = BiMap<Left, Right, StableHashMap<Left, Right const*, LeftHash>, StableHashMap<Right, Left const*, RightHash>>; + +// Case insensitive Enum <-> String map +template <typename EnumType> +using EnumMap = BiMap<EnumType, + String, + Map<EnumType, String const*>, + StableHashMap<String, EnumType const*, CaseInsensitiveStringHash, CaseInsensitiveStringCompare>>; + +template <typename LeftT, typename RightT, typename LeftMapT, typename RightMapT> +auto BiMap<LeftT, RightT, LeftMapT, RightMapT>::BiMapIterator::operator++() -> BiMapIterator & { + ++iterator; + return *this; +} + +template <typename LeftT, typename RightT, typename LeftMapT, typename RightMapT> +auto BiMap<LeftT, RightT, LeftMapT, RightMapT>::BiMapIterator::operator++(int) -> BiMapIterator { + BiMapIterator last{iterator}; + ++iterator; + return last; +} + +template <typename LeftT, typename RightT, typename LeftMapT, typename RightMapT> +bool BiMap<LeftT, RightT, LeftMapT, RightMapT>::BiMapIterator::operator==(BiMapIterator const& rhs) const { + return iterator == rhs.iterator; +} + +template <typename LeftT, typename RightT, typename LeftMapT, typename RightMapT> +bool BiMap<LeftT, RightT, LeftMapT, RightMapT>::BiMapIterator::operator!=(BiMapIterator const& rhs) const { + return iterator != rhs.iterator; +} + +template <typename LeftT, typename RightT, typename LeftMapT, typename RightMapT> +pair<LeftT const&, RightT const&> BiMap<LeftT, RightT, LeftMapT, RightMapT>::BiMapIterator::operator*() const { + return {iterator->first, *iterator->second}; +} + +template <typename LeftT, typename RightT, typename LeftMapT, typename RightMapT> +template <typename Collection> +BiMap<LeftT, RightT, LeftMapT, RightMapT> BiMap<LeftT, RightT, LeftMapT, RightMapT>::from(Collection const& c) { + return BiMap(c.begin(), c.end()); +} + +template <typename LeftT, typename RightT, typename LeftMapT, typename RightMapT> +BiMap<LeftT, RightT, LeftMapT, RightMapT>::BiMap() {} + +template <typename LeftT, typename RightT, typename LeftMapT, typename RightMapT> +BiMap<LeftT, RightT, LeftMapT, RightMapT>::BiMap(BiMap const& map) + : BiMap(map.begin(), map.end()) {} + +template <typename LeftT, typename RightT, typename LeftMapT, typename RightMapT> +template <typename InputIterator> +BiMap<LeftT, RightT, LeftMapT, RightMapT>::BiMap(InputIterator beg, InputIterator end) { + while (beg != end) { + insert(*beg); + ++beg; + } +} + +template <typename LeftT, typename RightT, typename LeftMapT, typename RightMapT> +BiMap<LeftT, RightT, LeftMapT, RightMapT>::BiMap(std::initializer_list<value_type> list) { + for (value_type const& v : list) { + if (!insert(v.first, v.second)) + throw MapException::format("Repeat pair in BiMap initializer_list construction: (%s, %s)", outputAny(v.first), outputAny(v.second)); + } +} + +template <typename LeftT, typename RightT, typename LeftMapT, typename RightMapT> +List<LeftT> BiMap<LeftT, RightT, LeftMapT, RightMapT>::leftValues() const { + return m_leftMap.keys(); +} + +template <typename LeftT, typename RightT, typename LeftMapT, typename RightMapT> +List<RightT> BiMap<LeftT, RightT, LeftMapT, RightMapT>::rightValues() const { + return m_rightMap.keys(); +} + +template <typename LeftT, typename RightT, typename LeftMapT, typename RightMapT> +auto BiMap<LeftT, RightT, LeftMapT, RightMapT>::pairs() const -> List<value_type> { + List<value_type> values; + for (auto const& p : *this) + values.append(p); + return values; +} + +template <typename LeftT, typename RightT, typename LeftMapT, typename RightMapT> +bool BiMap<LeftT, RightT, LeftMapT, RightMapT>::hasLeftValue(Left const& left) const { + return m_leftMap.contains(left); +} + +template <typename LeftT, typename RightT, typename LeftMapT, typename RightMapT> +bool BiMap<LeftT, RightT, LeftMapT, RightMapT>::hasRightValue(Right const& right) const { + return m_rightMap.contains(right); +} + +template <typename LeftT, typename RightT, typename LeftMapT, typename RightMapT> +RightT const& BiMap<LeftT, RightT, LeftMapT, RightMapT>::getRight(Left const& left) const { + return *m_leftMap.get(left); +} + +template <typename LeftT, typename RightT, typename LeftMapT, typename RightMapT> +LeftT const& BiMap<LeftT, RightT, LeftMapT, RightMapT>::getLeft(Right const& right) const { + return *m_rightMap.get(right); +} + +template <typename LeftT, typename RightT, typename LeftMapT, typename RightMapT> +RightT BiMap<LeftT, RightT, LeftMapT, RightMapT>::valueRight(Left const& left, Right const& def) const { + return maybeRight(left).value(def); +} + +template <typename LeftT, typename RightT, typename LeftMapT, typename RightMapT> +LeftT BiMap<LeftT, RightT, LeftMapT, RightMapT>::valueLeft(Right const& right, Left const& def) const { + return maybeLeft(right).value(def); +} + +template <typename LeftT, typename RightT, typename LeftMapT, typename RightMapT> +Maybe<RightT> BiMap<LeftT, RightT, LeftMapT, RightMapT>::maybeRight(Left const& left) const { + auto i = m_leftMap.find(left); + if (i != m_leftMap.end()) + return *i->second; + return {}; +} + +template <typename LeftT, typename RightT, typename LeftMapT, typename RightMapT> +Maybe<LeftT> BiMap<LeftT, RightT, LeftMapT, RightMapT>::maybeLeft(Right const& right) const { + auto i = m_rightMap.find(right); + if (i != m_rightMap.end()) + return *i->second; + return {}; +} + +template <typename LeftT, typename RightT, typename LeftMapT, typename RightMapT> +RightT BiMap<LeftT, RightT, LeftMapT, RightMapT>::takeRight(Left const& left) { + if (auto right = maybeTakeRight(left)) + return right.take(); + throw MapException::format("No such key in BiMap::takeRight", outputAny(left)); +} + +template <typename LeftT, typename RightT, typename LeftMapT, typename RightMapT> +LeftT BiMap<LeftT, RightT, LeftMapT, RightMapT>::takeLeft(Right const& right) { + if (auto left = maybeTakeLeft(right)) + return left.take(); + throw MapException::format("No such key in BiMap::takeLeft", outputAny(right)); +} + +template <typename LeftT, typename RightT, typename LeftMapT, typename RightMapT> +Maybe<RightT> BiMap<LeftT, RightT, LeftMapT, RightMapT>::maybeTakeRight(Left const& left) { + if (auto rightPtr = m_leftMap.maybeTake(left).value()) { + Right right = *rightPtr; + m_rightMap.remove(*rightPtr); + return right; + } else { + return {}; + } +} + +template <typename LeftT, typename RightT, typename LeftMapT, typename RightMapT> +Maybe<LeftT> BiMap<LeftT, RightT, LeftMapT, RightMapT>::maybeTakeLeft(Right const& right) { + if (auto leftPtr = m_rightMap.maybeTake(right).value()) { + Left left = *leftPtr; + m_leftMap.remove(*leftPtr); + return left; + } else { + return {}; + } +} + +template <typename LeftT, typename RightT, typename LeftMapT, typename RightMapT> +RightT const* BiMap<LeftT, RightT, LeftMapT, RightMapT>::rightPtr(Left const& left) const { + return m_leftMap.value(left); +} + +template <typename LeftT, typename RightT, typename LeftMapT, typename RightMapT> +LeftT const* BiMap<LeftT, RightT, LeftMapT, RightMapT>::leftPtr(Right const& right) const { + return m_rightMap.value(right); +} + +template <typename LeftT, typename RightT, typename LeftMapT, typename RightMapT> +BiMap<LeftT, RightT, LeftMapT, RightMapT>& BiMap<LeftT, RightT, LeftMapT, RightMapT>::operator=(BiMap const& map) { + if (this != &map) { + clear(); + for (auto const& p : map) + insert(p); + } + return *this; +} + +template <typename LeftT, typename RightT, typename LeftMapT, typename RightMapT> +auto BiMap<LeftT, RightT, LeftMapT, RightMapT>::insert(value_type const& val) -> pair<iterator, bool> { + auto leftRes = m_leftMap.insert(make_pair(val.first, nullptr)); + if (!leftRes.second) + return {BiMapIterator{leftRes.first}, false}; + + auto rightRes = m_rightMap.insert(make_pair(val.second, nullptr)); + starAssert(rightRes.second == true); + leftRes.first->second = &rightRes.first->first; + rightRes.first->second = &leftRes.first->first; + return {BiMapIterator{leftRes.first}, true}; +}; + +template <typename LeftT, typename RightT, typename LeftMapT, typename RightMapT> +bool BiMap<LeftT, RightT, LeftMapT, RightMapT>::insert(Left const& left, Right const& right) { + return insert(make_pair(left, right)).second; +} + +template <typename LeftT, typename RightT, typename LeftMapT, typename RightMapT> +void BiMap<LeftT, RightT, LeftMapT, RightMapT>::add(Left const& left, Right const& right) { + if (m_leftMap.contains(left)) + throw MapException(strf("BiMap already contains left side value '%s'", outputAny(left))); + + if (m_rightMap.contains(right)) + throw MapException(strf("BiMap already contains right side value '%s'", outputAny(right))); + + insert(left, right); +} + +template <typename LeftT, typename RightT, typename LeftMapT, typename RightMapT> +void BiMap<LeftT, RightT, LeftMapT, RightMapT>::add(value_type const& value) { + add(value.first, value.second); +} + +template <typename LeftT, typename RightT, typename LeftMapT, typename RightMapT> +void BiMap<LeftT, RightT, LeftMapT, RightMapT>::overwrite(Left const& left, Right const& right) { + removeLeft(left); + removeRight(right); + insert(left, right); +} + +template <typename LeftT, typename RightT, typename LeftMapT, typename RightMapT> +void BiMap<LeftT, RightT, LeftMapT, RightMapT>::overwrite(value_type const& value) { + return overwrite(value.first, value.second); +} + +template <typename LeftT, typename RightT, typename LeftMapT, typename RightMapT> +bool BiMap<LeftT, RightT, LeftMapT, RightMapT>::removeLeft(Left const& left) { + if (auto right = m_leftMap.value(left)) { + m_rightMap.remove(*right); + m_leftMap.remove(left); + return true; + } + + return false; +} + +template <typename LeftT, typename RightT, typename LeftMapT, typename RightMapT> +bool BiMap<LeftT, RightT, LeftMapT, RightMapT>::removeRight(Right const& right) { + if (auto left = m_rightMap.value(right)) { + m_leftMap.remove(*left); + m_rightMap.remove(right); + return true; + } + + return false; +} + +template <typename LeftT, typename RightT, typename LeftMapT, typename RightMapT> +auto BiMap<LeftT, RightT, LeftMapT, RightMapT>::begin() const -> const_iterator { + return BiMapIterator{m_leftMap.begin()}; +} + +template <typename LeftT, typename RightT, typename LeftMapT, typename RightMapT> +auto BiMap<LeftT, RightT, LeftMapT, RightMapT>::end() const -> const_iterator { + return BiMapIterator{m_leftMap.end()}; +} + +template <typename LeftT, typename RightT, typename LeftMapT, typename RightMapT> +size_t BiMap<LeftT, RightT, LeftMapT, RightMapT>::size() const { + return m_leftMap.size(); +} + +template <typename LeftT, typename RightT, typename LeftMapT, typename RightMapT> +void BiMap<LeftT, RightT, LeftMapT, RightMapT>::clear() { + m_leftMap.clear(); + m_rightMap.clear(); +} + +template <typename LeftT, typename RightT, typename LeftMapT, typename RightMapT> +bool BiMap<LeftT, RightT, LeftMapT, RightMapT>::empty() const { + return m_leftMap.empty(); +} + +template <typename LeftT, typename RightT, typename LeftMapT, typename RightMapT> +bool BiMap<LeftT, RightT, LeftMapT, RightMapT>::operator==(BiMap const& m) const { + if (&m == this) + return true; + + if (size() != m.size()) + return false; + + for (auto const& pair : *this) { + if (auto p = m.rightPtr(pair.first)) + if (!p || *p != pair.second) + return false; + } + + return true; +} + +} + +#endif diff --git a/source/core/StarBlockAllocator.hpp b/source/core/StarBlockAllocator.hpp new file mode 100644 index 0000000..5051e66 --- /dev/null +++ b/source/core/StarBlockAllocator.hpp @@ -0,0 +1,268 @@ +#ifndef STAR_BLOCK_ALLOCATOR_HPP +#define STAR_BLOCK_ALLOCATOR_HPP + +#include <array> +#include <vector> +#include <unordered_map> +#include <limits> +#include <typeindex> + +#include "StarException.hpp" + +namespace Star { + +// Constant size only allocator using fixed size blocks of memory. much faster +// than general purpose allocators, but not thread safe. Useful as the +// allocator for containers that mostly allocate one element at a time, such as +// std::list, std::map, std::set etc. +template <typename T, size_t BlockSize> +class BlockAllocator { +public: + typedef T value_type; + + typedef T* pointer; + typedef T const* const_pointer; + + typedef T& reference; + typedef T const& const_reference; + + // Allocator can be shared, but since it is NOT thread safe this should not + // be done by default. + typedef std::false_type propagate_on_container_copy_assignment; + typedef std::true_type propagate_on_container_move_assignment; + typedef std::true_type propagate_on_container_swap; + + template <class U> + struct rebind { + typedef BlockAllocator<U, BlockSize> other; + }; + + BlockAllocator(); + // Copy constructed BlockAllocators of the same type share underlying + // resources. + BlockAllocator(BlockAllocator const& other) = default; + BlockAllocator(BlockAllocator&& other) = default; + // Copy constructed BlockAllocators of different type share no resources + template <class U> + BlockAllocator(BlockAllocator<U, BlockSize> const& other); + + BlockAllocator& operator=(BlockAllocator const& rhs) = default; + BlockAllocator& operator=(BlockAllocator&& rhs) = default; + + // If n is != 1, will fall back on std::allocator<T> + T* allocate(size_t n); + void deallocate(T* p, size_t n); + + template <typename... Args> + void construct(pointer p, Args&&... args) const; + void destroy(pointer p) const; + + // BlockAllocator will always be != to any other BlockAllocator instance + template <class U> + bool operator==(BlockAllocator<U, BlockSize> const& rhs) const; + template <class U> + bool operator!=(BlockAllocator<U, BlockSize> const& rhs) const; + +private: + template <typename OtherT, size_t OtherBlockSize> + friend class BlockAllocator; + + using ChunkIndex = + std::conditional_t<BlockSize <= std::numeric_limits<uint8_t>::max(), uint8_t, + std::conditional_t<BlockSize <= std::numeric_limits<uint16_t>::max(), uint16_t, + std::conditional_t<BlockSize <= std::numeric_limits<uint32_t>::max(), uint32_t, + std::conditional_t<BlockSize <= std::numeric_limits<uint64_t>::max(), uint64_t, uintmax_t>>>>; + + static ChunkIndex const NullChunkIndex = std::numeric_limits<ChunkIndex>::max(); + + struct Unallocated { + ChunkIndex prev; + ChunkIndex next; + }; + + typedef std::aligned_union_t<0, T, Unallocated> Chunk; + + struct Block { + T* allocate(); + void deallocate(T* ptr); + + bool full() const; + bool empty() const; + + Chunk* chunkPointer(ChunkIndex chunkIndex); + + std::array<Chunk, BlockSize> chunks; + ChunkIndex firstUnallocated = NullChunkIndex; + ChunkIndex allocationCount = 0; + }; + + struct Data { + std::vector<unique_ptr<Block>> blocks; + Block* unfilledBlock; + std::allocator<T> multiAllocator; + }; + + typedef std::unordered_map<std::type_index, shared_ptr<void>> BlockAllocatorFamily; + + static Data* getAllocatorData(BlockAllocatorFamily& family); + + shared_ptr<BlockAllocatorFamily> m_family; + Data* m_data; +}; + +template <typename T, size_t BlockSize> +BlockAllocator<T, BlockSize>::BlockAllocator() { + m_family = make_shared<BlockAllocatorFamily>(); + m_data = getAllocatorData(*m_family); + m_data->blocks.reserve(32); + m_data->unfilledBlock = nullptr; +} + +template <typename T, size_t BlockSize> +template <class U> +BlockAllocator<T, BlockSize>::BlockAllocator(BlockAllocator<U, BlockSize> const& other) + : m_family(other.m_family) { + m_data = getAllocatorData(*m_family); +} + +template <typename T, size_t BlockSize> +T* BlockAllocator<T, BlockSize>::allocate(size_t n) { + if (n == 1) { + if (m_data->unfilledBlock == nullptr) { + for (auto const& p : m_data->blocks) { + if (!p->full()) { + m_data->unfilledBlock = p.get(); + break; + } + } + + if (!m_data->unfilledBlock) { + auto block = make_unique<Block>(); + m_data->unfilledBlock = block.get(); + auto sortedPosition = std::lower_bound(m_data->blocks.begin(), m_data->blocks.end(), block.get(), [](std::unique_ptr<Block> const& a, Block* b) { + return a.get() < b; + }); + m_data->blocks.insert(sortedPosition, move(block)); + } + } + + auto allocated = m_data->unfilledBlock->allocate(); + if (m_data->unfilledBlock->full()) + m_data->unfilledBlock = nullptr; + return allocated; + } else { + return m_data->multiAllocator.allocate(n); + } +} + +template <typename T, size_t BlockSize> +void BlockAllocator<T, BlockSize>::deallocate(T* p, size_t n) { + if (n == 1) { + starAssert(p); + + auto i = std::upper_bound(m_data->blocks.begin(), m_data->blocks.end(), p, [](T* a, std::unique_ptr<Block> const& b) { + return a < (T*)b->chunkPointer(0); + }); + + starAssert(i != m_data->blocks.begin()); + --i; + + (*i)->deallocate(p); + + if (!m_data->unfilledBlock) { + m_data->unfilledBlock = i->get(); + } else if ((*i)->empty()) { + if (m_data->unfilledBlock != i->get()) + m_data->blocks.erase(i); + } + } else { + m_data->multiAllocator.deallocate(p, n); + } +} + +template <typename T, size_t BlockSize> +template <typename... Args> +void BlockAllocator<T, BlockSize>::construct(pointer p, Args&&... args) const { + new (p) T(forward<Args>(args)...); +} + +template <typename T, size_t BlockSize> +void BlockAllocator<T, BlockSize>::destroy(pointer p) const { + p->~T(); +} + +template <typename T, size_t BlockSize> +template <class U> +bool BlockAllocator<T, BlockSize>::operator==(BlockAllocator<U, BlockSize> const& rhs) const { + return m_family == rhs.m_family; +} + +template <typename T, size_t BlockSize> +template <class U> +bool BlockAllocator<T, BlockSize>::operator!=(BlockAllocator<U, BlockSize> const& rhs) const { + return m_family != rhs.m_family; +} + +template <typename T, size_t BlockSize> +T* BlockAllocator<T, BlockSize>::Block::allocate() { + starAssert(allocationCount < BlockSize); + + T* allocated; + if (firstUnallocated == NullChunkIndex) { + allocated = (T*)chunkPointer(allocationCount); + } else { + void* chunk = chunkPointer(firstUnallocated); + starAssert(((Unallocated*)chunk)->prev == NullChunkIndex); + firstUnallocated = ((Unallocated*)chunk)->next; + if (firstUnallocated != NullChunkIndex) + ((Unallocated*)chunkPointer(firstUnallocated))->prev = NullChunkIndex; + allocated = (T*)chunk; + } + + ++allocationCount; + return allocated; +} + +template <typename T, size_t BlockSize> +void BlockAllocator<T, BlockSize>::Block::deallocate(T* ptr) { + starAssert(allocationCount > 0); + + ChunkIndex chunkIndex = ptr - (T*)chunkPointer(0); + starAssert((T*)chunkPointer(chunkIndex) == ptr); + + auto c = (Unallocated*)chunkPointer(chunkIndex); + c->prev = NullChunkIndex; + c->next = firstUnallocated; + if (firstUnallocated != NullChunkIndex) + ((Unallocated*)chunkPointer(firstUnallocated))->prev = chunkIndex; + firstUnallocated = chunkIndex; + --allocationCount; +} + +template <typename T, size_t BlockSize> +bool BlockAllocator<T, BlockSize>::Block::full() const { + return allocationCount == BlockSize; +} + +template <typename T, size_t BlockSize> +bool BlockAllocator<T, BlockSize>::Block::empty() const { + return allocationCount == 0; +} + +template <typename T, size_t BlockSize> +auto BlockAllocator<T, BlockSize>::Block::chunkPointer(ChunkIndex chunkIndex) -> Chunk* { + starAssert(chunkIndex < BlockSize); + return &chunks[chunkIndex]; +} + +template <typename T, size_t BlockSize> +typename BlockAllocator<T, BlockSize>::Data* BlockAllocator<T, BlockSize>::getAllocatorData(BlockAllocatorFamily& family) { + auto& dataptr = family[typeid(Data)]; + if (!dataptr) + dataptr = make_shared<Data>(); + return (Data*)dataptr.get(); +} + +} + +#endif diff --git a/source/core/StarBuffer.cpp b/source/core/StarBuffer.cpp new file mode 100644 index 0000000..fbed59a --- /dev/null +++ b/source/core/StarBuffer.cpp @@ -0,0 +1,287 @@ +#include "StarBuffer.hpp" +#include "StarMathCommon.hpp" +#include "StarIODevice.hpp" +#include "StarFormat.hpp" + +namespace Star { + +Buffer::Buffer() + : m_pos(0) { + setMode(IOMode::ReadWrite); +} + +Buffer::Buffer(size_t initialSize) + : Buffer() { + reset(initialSize); +} + +Buffer::Buffer(ByteArray b) + : Buffer() { + reset(move(b)); +} + +Buffer::Buffer(Buffer const& buffer) + : Buffer() { + operator=(buffer); +} + +Buffer::Buffer(Buffer&& buffer) + : Buffer() { + operator=(move(buffer)); +} + +StreamOffset Buffer::pos() { + return m_pos; +} + +void Buffer::seek(StreamOffset pos, IOSeek mode) { + StreamOffset newPos = m_pos; + if (mode == IOSeek::Absolute) + newPos = pos; + else if (mode == IOSeek::Relative) + newPos += pos; + else if (mode == IOSeek::End) + newPos = m_bytes.size() - pos; + m_pos = newPos; +} + +void Buffer::resize(StreamOffset size) { + data().resize((size_t)size); +} + +bool Buffer::atEnd() { + return m_pos >= m_bytes.size(); +} + +size_t Buffer::read(char* data, size_t len) { + size_t l = doRead(m_pos, data, len); + m_pos += l; + return l; +} + +size_t Buffer::write(char const* data, size_t len) { + size_t l = doWrite(m_pos, data, len); + m_pos += l; + return l; +} + +size_t Buffer::readAbsolute(StreamOffset readPosition, char* data, size_t len) { + size_t rpos = readPosition; + if ((StreamOffset)rpos != readPosition) + throw IOException("Error, readPosition out of range"); + + return doRead(rpos, data, len); +} + +size_t Buffer::writeAbsolute(StreamOffset writePosition, char const* data, size_t len) { + size_t wpos = writePosition; + if ((StreamOffset)wpos != writePosition) + throw IOException("Error, writePosition out of range"); + + return doWrite(wpos, data, len); +} + +void Buffer::open(IOMode mode) { + setMode(mode); + if (mode & IOMode::Write && mode & IOMode::Truncate) + resize(0); + if (mode & IOMode::Append) + seek(0, IOSeek::End); +} + +String Buffer::deviceName() const { + return strf("Buffer <%s>", this); +} + +StreamOffset Buffer::size() { + return m_bytes.size(); +} + +ByteArray& Buffer::data() { + return m_bytes; +} + +ByteArray const& Buffer::data() const { + return m_bytes; +} + +ByteArray Buffer::takeData() { + ByteArray ret = move(m_bytes); + reset(0); + return ret; +} + +char* Buffer::ptr() { + return data().ptr(); +} + +char const* Buffer::ptr() const { + return m_bytes.ptr(); +} + +size_t Buffer::dataSize() const { + return m_bytes.size(); +} + +void Buffer::reserve(size_t size) { + data().reserve(size); +} + +void Buffer::clear() { + m_pos = 0; + m_bytes.clear(); +} + +bool Buffer::empty() const { + return m_bytes.empty(); +} + +void Buffer::reset(size_t newSize) { + m_pos = 0; + m_bytes.fill(newSize, 0); +} + +void Buffer::reset(ByteArray b) { + m_pos = 0; + m_bytes = move(b); +} + +Buffer& Buffer::operator=(Buffer const& buffer) { + IODevice::operator=(buffer); + m_pos = buffer.m_pos; + m_bytes = buffer.m_bytes; + return *this; +} + +Buffer& Buffer::operator=(Buffer&& buffer) { + IODevice::operator=(buffer); + m_pos = buffer.m_pos; + m_bytes = move(buffer.m_bytes); + + buffer.m_pos = 0; + buffer.m_bytes = ByteArray(); + + return *this; +} + +size_t Buffer::doRead(size_t pos, char* data, size_t len) { + if (len == 0) + return 0; + + if (!isReadable()) + throw IOException("Error, read called on non-readable Buffer"); + + if (pos >= m_bytes.size()) + return 0; + + size_t l = min(m_bytes.size() - pos, len); + memcpy(data, m_bytes.ptr() + pos, l); + return l; +} + +size_t Buffer::doWrite(size_t pos, char const* data, size_t len) { + if (len == 0) + return 0; + + if (!isWritable()) + throw EofException("Error, write called on non-writable Buffer"); + + if (pos + len > m_bytes.size()) + m_bytes.resize(pos + len); + + memcpy(m_bytes.ptr() + pos, data, len); + return len; +} + +ExternalBuffer::ExternalBuffer() + : m_pos(0), m_bytes(nullptr), m_size(0) { + setMode(IOMode::Read); +} + +ExternalBuffer::ExternalBuffer(char const* externalData, size_t len) : ExternalBuffer() { + reset(externalData, len); +} + +StreamOffset ExternalBuffer::pos() { + return m_pos; +} + +void ExternalBuffer::seek(StreamOffset pos, IOSeek mode) { + StreamOffset newPos = m_pos; + if (mode == IOSeek::Absolute) + newPos = pos; + else if (mode == IOSeek::Relative) + newPos += pos; + else if (mode == IOSeek::End) + newPos = m_size - pos; + m_pos = newPos; +} + +bool ExternalBuffer::atEnd() { + return m_pos >= m_size; +} + +size_t ExternalBuffer::read(char* data, size_t len) { + size_t l = doRead(m_pos, data, len); + m_pos += l; + return l; +} + +size_t ExternalBuffer::write(char const*, size_t) { + throw IOException("Error, ExternalBuffer is not writable"); +} + +size_t ExternalBuffer::readAbsolute(StreamOffset readPosition, char* data, size_t len) { + size_t rpos = readPosition; + if ((StreamOffset)rpos != readPosition) + throw IOException("Error, readPosition out of range"); + + return doRead(rpos, data, len); +} + +size_t ExternalBuffer::writeAbsolute(StreamOffset, char const*, size_t) { + throw IOException("Error, ExternalBuffer is not writable"); +} + +String ExternalBuffer::deviceName() const { + return strf("ExternalBuffer <%s>", this); +} + +StreamOffset ExternalBuffer::size() { + return m_size; +} + +char const* ExternalBuffer::ptr() const { + return m_bytes; +} + +size_t ExternalBuffer::dataSize() const { + return m_size; +} + +bool ExternalBuffer::empty() const { + return m_size == 0; +} + +void ExternalBuffer::reset(char const* externalData, size_t len) { + m_pos = 0; + m_bytes = externalData; + m_size = len; +} + +size_t ExternalBuffer::doRead(size_t pos, char* data, size_t len) { + if (len == 0) + return 0; + + if (!isReadable()) + throw IOException("Error, read called on non-readable Buffer"); + + if (pos >= m_size) + return 0; + + size_t l = min(m_size - pos, len); + memcpy(data, m_bytes + pos, l); + return l; +} + +} diff --git a/source/core/StarBuffer.hpp b/source/core/StarBuffer.hpp new file mode 100644 index 0000000..a92138b --- /dev/null +++ b/source/core/StarBuffer.hpp @@ -0,0 +1,122 @@ +#ifndef STAR_BUFFER_HPP +#define STAR_BUFFER_HPP + +#include "StarIODevice.hpp" +#include "StarString.hpp" + +namespace Star { + +STAR_CLASS(Buffer); +STAR_CLASS(ExternalBuffer); + +// Wraps a ByteArray to an IODevice +class Buffer : public IODevice { +public: + // Constructs buffer open ReadWrite + Buffer(); + Buffer(size_t initialSize); + Buffer(ByteArray b); + Buffer(Buffer const& buffer); + Buffer(Buffer&& buffer); + + StreamOffset pos() override; + void seek(StreamOffset pos, IOSeek mode = IOSeek::Absolute) override; + void resize(StreamOffset size) override; + bool atEnd() override; + + size_t read(char* data, size_t len) override; + size_t write(char const* data, size_t len) override; + + size_t readAbsolute(StreamOffset readPosition, char* data, size_t len) override; + size_t writeAbsolute(StreamOffset writePosition, char const* data, size_t len) override; + + void open(IOMode mode) override; + + String deviceName() const override; + + StreamOffset size() override; + + ByteArray& data(); + ByteArray const& data() const; + + // If this class holds the underlying data, then this method is cheap, and + // will move the data out of this class into the returned array, otherwise, + // this will incur a copy. Afterwards, this Buffer will be left empty. + ByteArray takeData(); + + // Returns a pointer to the beginning of the Buffer. + char* ptr(); + char const* ptr() const; + + // Same thing as size(), just size_t type (since this is in-memory) + size_t dataSize() const; + void reserve(size_t size); + + // Clears buffer, moves position to 0. + void clear(); + bool empty() const; + + // Reset buffer with new contents, moves position to 0. + void reset(size_t newSize); + void reset(ByteArray b); + + Buffer& operator=(Buffer const& buffer); + Buffer& operator=(Buffer&& buffer); + +private: + size_t doRead(size_t pos, char* data, size_t len); + size_t doWrite(size_t pos, char const* data, size_t len); + + size_t m_pos; + ByteArray m_bytes; +}; + +// Wraps an externally held sequence of bytes to a read-only IODevice +class ExternalBuffer : public IODevice { +public: + // Constructs an empty ReadOnly ExternalBuffer. + ExternalBuffer(); + // Constructs a ReadOnly ExternalBuffer pointing to the given external data, which + // must be valid for the lifetime of the ExternalBuffer. + ExternalBuffer(char const* externalData, size_t len); + + ExternalBuffer(ExternalBuffer const& buffer) = default; + ExternalBuffer& operator=(ExternalBuffer const& buffer) = default; + + StreamOffset pos() override; + void seek(StreamOffset pos, IOSeek mode = IOSeek::Absolute) override; + bool atEnd() override; + + size_t read(char* data, size_t len) override; + size_t write(char const* data, size_t len) override; + + size_t readAbsolute(StreamOffset readPosition, char* data, size_t len) override; + size_t writeAbsolute(StreamOffset writePosition, char const* data, size_t len) override; + + String deviceName() const override; + + StreamOffset size() override; + + // Returns a pointer to the beginning of the Buffer. + char const* ptr() const; + + // Same thing as size(), just size_t type (since this is in-memory) + size_t dataSize() const; + + // Clears buffer, moves position to 0. + bool empty() const; + + // Reset buffer with new contents, moves position to 0. + void reset(char const* externalData, size_t len); + +private: + size_t doRead(size_t pos, char* data, size_t len); + + size_t m_pos; + char const* m_bytes; + size_t m_size; +}; + +} + +#endif diff --git a/source/core/StarByteArray.cpp b/source/core/StarByteArray.cpp new file mode 100644 index 0000000..944071b --- /dev/null +++ b/source/core/StarByteArray.cpp @@ -0,0 +1,258 @@ +#include "StarByteArray.hpp" +#include "StarEncode.hpp" + +namespace Star { + +ByteArray ByteArray::fromCString(char const* str) { + return ByteArray(str, strlen(str)); +} + +ByteArray ByteArray::fromCStringWithNull(char const* str) { + size_t len = strlen(str); + ByteArray ba(str, len + 1); + ba[len] = 0; + return ba; +} + +ByteArray ByteArray::withReserve(size_t capacity) { + ByteArray bytes; + bytes.reserve(capacity); + return bytes; +} + +ByteArray::ByteArray() { + m_data = nullptr; + m_capacity = 0; + m_size = 0; +} + +ByteArray::ByteArray(size_t dataSize, char c) + : ByteArray() { + fill(dataSize, c); +} + +ByteArray::ByteArray(const char* data, size_t dataSize) + : ByteArray() { + append(data, dataSize); +} + +ByteArray::ByteArray(ByteArray const& b) + : ByteArray() { + operator=(b); +} + +ByteArray::ByteArray(ByteArray&& b) noexcept + : ByteArray() { + operator=(move(b)); +} + +ByteArray::~ByteArray() { + reset(); +} + +ByteArray& ByteArray::operator=(ByteArray const& b) { + if (&b != this) { + clear(); + append(b); + } + + return *this; +} + +ByteArray& ByteArray::operator=(ByteArray&& b) noexcept { + if (&b != this) { + reset(); + + m_data = take(b.m_data); + m_capacity = take(b.m_capacity); + m_size = take(b.m_size); + } + + return *this; +} + +void ByteArray::reset() { + if (m_data) { + Star::free(m_data, m_capacity); + m_data = nullptr; + m_capacity = 0; + m_size = 0; + } +} + +void ByteArray::reserve(size_t newCapacity) { + if (newCapacity > m_capacity) { + if (!m_data) { + auto newMem = (char*)Star::malloc(newCapacity); + if (!newMem) + throw MemoryException::format("Could not set new ByteArray capacity %s\n", newCapacity); + m_data = newMem; + m_capacity = newCapacity; + } else { + newCapacity = max({m_capacity * 2, newCapacity, (size_t)8}); + auto newMem = (char*)Star::realloc(m_data, newCapacity); + if (!newMem) + throw MemoryException::format("Could not set new ByteArray capacity %s\n", newCapacity); + m_data = newMem; + m_capacity = newCapacity; + } + } +} + +void ByteArray::resize(size_t size, char f) { + if (m_size == size) + return; + + size_t oldSize = m_size; + resize(size); + for (size_t i = oldSize; i < m_size; ++i) + (*this)[i] = f; +} + +void ByteArray::fill(size_t s, char c) { + if (s != NPos) + resize(s); + + memset(m_data, c, m_size); +} + +void ByteArray::fill(char c) { + fill(NPos, c); +} + +ByteArray ByteArray::sub(size_t b, size_t s) const { + if (b == 0 && s >= m_size) { + return ByteArray(*this); + } else { + return ByteArray(m_data + b, min(m_size, b + s)); + } +} + +ByteArray ByteArray::left(size_t s) const { + return sub(0, s); +} + +ByteArray ByteArray::right(size_t s) const { + if (s > m_size) + s = 0; + else + s = m_size - s; + + return sub(s, m_size); +} + +void ByteArray::trimLeft(size_t s) { + if (s >= m_size) { + clear(); + } else { + std::memmove(m_data, m_data + s, m_size - s); + resize(m_size - s); + } +} + +void ByteArray::trimRight(size_t s) { + if (s >= m_size) + clear(); + else + resize(m_size - s); +} + +size_t ByteArray::diffChar(const ByteArray& b) const { + size_t s = min(m_size, b.size()); + char* ac = m_data; + char* bc = b.m_data; + size_t i; + for (i = 0; i < s; ++i) { + if (ac[i] != bc[i]) + break; + } + + return i; +} + +int ByteArray::compare(const ByteArray& b) const { + if (m_size == 0 && b.m_size == 0) + return 0; + + if (m_size == 0) + return -1; + + if (b.m_size == 0) + return 1; + + size_t d = diffChar(b); + if (d == m_size) { + if (d != b.m_size) + return -1; + else + return 0; + } + + if (d == b.m_size) { + if (d != m_size) + return 1; + else + return 0; + } + + unsigned char c1 = (*this)[d]; + unsigned char c2 = b[d]; + + if (c1 < c2) { + return -1; + } else if (c1 > c2) { + return 1; + } else { + return 0; + } +} + +ByteArray ByteArray::andWith(ByteArray const& rhs, bool extend) { + return combineWith([](char a, char b) { return a & b; }, rhs, extend); +} + +ByteArray ByteArray::orWith(ByteArray const& rhs, bool extend) { + return combineWith([](char a, char b) { return a | b; }, rhs, extend); +} + +ByteArray ByteArray::xorWith(ByteArray const& rhs, bool extend) { + return combineWith([](char a, char b) { return a ^ b; }, rhs, extend); +} + +void ByteArray::insert(size_t pos, char byte) { + starAssert(pos <= m_size); + resize(m_size + 1); + for (size_t i = m_size - 1; i > pos; --i) + m_data[i] = m_data[i - 1]; + m_data[pos] = byte; +} + +ByteArray::iterator ByteArray::insert(const_iterator pos, char byte) { + size_t d = pos - begin(); + insert(d, byte); + return begin() + d + 1; +} + +void ByteArray::push_back(char byte) { + resize(m_size + 1); + m_data[m_size - 1] = byte; +} + +bool ByteArray::operator<(const ByteArray& b) const { + return compare(b) < 0; +} + +bool ByteArray::operator==(const ByteArray& b) const { + return compare(b) == 0; +} + +bool ByteArray::operator!=(const ByteArray& b) const { + return compare(b) != 0; +} + +std::ostream& operator<<(std::ostream& os, const ByteArray& b) { + os << "0x" << hexEncode(b); + return os; +} + +} diff --git a/source/core/StarByteArray.hpp b/source/core/StarByteArray.hpp new file mode 100644 index 0000000..051895f --- /dev/null +++ b/source/core/StarByteArray.hpp @@ -0,0 +1,259 @@ +#ifndef STAR_BYTE_ARRAY_H +#define STAR_BYTE_ARRAY_H + +#include "StarHash.hpp" +#include "StarException.hpp" +#include "StarFormat.hpp" + +namespace Star { + +STAR_CLASS(ByteArray); + +// Class to hold an array of bytes. Contains an internal buffer that may be +// larger than what is reported by size(), to avoid repeated allocations when a +// ByteArray grows. +class ByteArray { +public: + typedef char value_type; + typedef char* iterator; + typedef char const* const_iterator; + + // Constructs a byte array from a given c string WITHOUT including the + // trailing '\0' + static ByteArray fromCString(char const* str); + // Same, but includes the trailing '\0' + static ByteArray fromCStringWithNull(char const* str); + static ByteArray withReserve(size_t capacity); + + ByteArray(); + ByteArray(size_t dataSize, char c); + ByteArray(char const* data, size_t dataSize); + ByteArray(ByteArray const& b); + ByteArray(ByteArray&& b) noexcept; + ~ByteArray(); + + ByteArray& operator=(ByteArray const& b); + ByteArray& operator=(ByteArray&& b) noexcept; + + char const* ptr() const; + char* ptr(); + + size_t size() const; + // Maximum size before realloc + size_t capacity() const; + // Is zero size + bool empty() const; + + // Sets size to 0. + void clear(); + // Clears and resets buffer to empty. + void reset(); + + void reserve(size_t capacity); + + void resize(size_t size); + // resize, filling new space with given byte if it exists. + void resize(size_t size, char f); + + // fill array with byte. + void fill(char c); + // fill array and resize to new size. + void fill(size_t size, char c); + + void append(ByteArray const& b); + void append(char const* data, size_t len); + void appendByte(char b); + + void copyTo(char* data, size_t len) const; + void copyTo(char* data) const; + + // Copy from ByteArray starting at pos, to data, with size len. + void copyTo(char* data, size_t pos, size_t len) const; + // Copy from data pointer to ByteArray at pos with size len. + // Resizes if needed. + void writeFrom(char const* data, size_t pos, size_t len); + + ByteArray sub(size_t b, size_t s) const; + ByteArray left(size_t s) const; + ByteArray right(size_t s) const; + + void trimLeft(size_t s); + void trimRight(size_t s); + + // returns location of first character that is different than the given + // ByteArray. + size_t diffChar(ByteArray const& b) const; + // returns -1 if this < b, 0 if this == b, 1 if this > b + int compare(ByteArray const& b) const; + + template <typename Combiner> + ByteArray combineWith(Combiner&& combine, ByteArray const& rhs, bool extend = false); + + ByteArray andWith(ByteArray const& rhs, bool extend = false); + ByteArray orWith(ByteArray const& rhs, bool extend = false); + ByteArray xorWith(ByteArray const& rhs, bool extend = false); + + iterator begin(); + iterator end(); + + const_iterator begin() const; + const_iterator end() const; + + void insert(size_t pos, char byte); + iterator insert(const_iterator pos, char byte); + void push_back(char byte); + + char& operator[](size_t i); + char operator[](size_t i) const; + char at(size_t i) const; + + bool operator<(ByteArray const& b) const; + bool operator==(ByteArray const& b) const; + bool operator!=(ByteArray const& b) const; + +private: + char* m_data; + size_t m_capacity; + size_t m_size; +}; + +template <> +struct hash<ByteArray> { + size_t operator()(ByteArray const& b) const; +}; + +std::ostream& operator<<(std::ostream& os, ByteArray const& b); + +inline void ByteArray::clear() { + resize(0); +} + +inline void ByteArray::resize(size_t size) { + reserve(size); + m_size = size; +} + +inline void ByteArray::append(ByteArray const& b) { + append(b.ptr(), b.size()); +} + +inline void ByteArray::append(const char* data, size_t len) { + resize(m_size + len); + std::memcpy(m_data + m_size - len, data, len); +} + +inline void ByteArray::appendByte(char b) { + resize(m_size + 1); + m_data[m_size - 1] = b; +} + +inline bool ByteArray::empty() const { + return m_size == 0; +} + +inline char const* ByteArray::ptr() const { + return m_data; +} + +inline char* ByteArray::ptr() { + return m_data; +} + +inline size_t ByteArray::size() const { + return m_size; +} + +inline size_t ByteArray::capacity() const { + return m_capacity; +} + +inline void ByteArray::copyTo(char* data, size_t len) const { + len = min(m_size, len); + std::memcpy(data, m_data, len); +} + +inline void ByteArray::copyTo(char* data) const { + copyTo(data, m_size); +} + +inline void ByteArray::copyTo(char* data, size_t pos, size_t len) const { + if (len == 0 || pos >= m_size) + return; + + len = min(m_size - pos, len); + std::memcpy(data, m_data + pos, len); +} + +inline void ByteArray::writeFrom(const char* data, size_t pos, size_t len) { + if (pos + len > m_size) + resize(pos + len); + + std::memcpy(m_data + pos, data, len); +} + +template <typename Combiner> +ByteArray ByteArray::combineWith(Combiner&& combine, ByteArray const& rhs, bool extend) { + ByteArray const* smallerArray = &rhs; + ByteArray const* largerArray = this; + + if (m_size < rhs.size()) + swap(smallerArray, largerArray); + + ByteArray res; + res.resize(smallerArray->size()); + + for (size_t i = 0; i < smallerArray->size(); ++i) + res[i] = combine((*smallerArray)[i], (*largerArray)[i]); + + if (extend) { + res.resize(largerArray->size()); + for (size_t i = smallerArray->size(); i < largerArray->size(); ++i) + res[i] = (*largerArray)[i]; + } + + return res; +} + +inline ByteArray::iterator ByteArray::begin() { + return m_data; +} + +inline ByteArray::iterator ByteArray::end() { + return m_data + m_size; +} + +inline ByteArray::const_iterator ByteArray::begin() const { + return m_data; +} + +inline ByteArray::const_iterator ByteArray::end() const { + return m_data + m_size; +} + +inline char& ByteArray::operator[](size_t i) { + starAssert(i < m_size); + return m_data[i]; +} + +inline char ByteArray::operator[](size_t i) const { + starAssert(i < m_size); + return m_data[i]; +} + +inline char ByteArray::at(size_t i) const { + if (i >= m_size) + throw OutOfRangeException(strf("Out of range in ByteArray::at(%s)", i)); + + return m_data[i]; +} + +inline size_t hash<ByteArray>::operator()(ByteArray const& b) const { + PLHasher hash; + for (size_t i = 0; i < b.size(); ++i) + hash.put(b[i]); + return hash.hash(); +} + +} + +#endif diff --git a/source/core/StarBytes.hpp b/source/core/StarBytes.hpp new file mode 100644 index 0000000..3dd013f --- /dev/null +++ b/source/core/StarBytes.hpp @@ -0,0 +1,109 @@ +#ifndef STAR_BYTES_HPP +#define STAR_BYTES_HPP + +#include "StarMemory.hpp" + +namespace Star { + +enum class ByteOrder { + BigEndian, + LittleEndian, + NoConversion +}; + +ByteOrder platformByteOrder(); + +void swapByteOrder(void* ptr, size_t len); +void swapByteOrder(void* dest, void const* src, size_t len); + +void toByteOrder(ByteOrder order, void* ptr, size_t len); +void toByteOrder(ByteOrder order, void* dest, void const* src, size_t len); +void fromByteOrder(ByteOrder order, void* ptr, size_t len); +void fromByteOrder(ByteOrder order, void* dest, void const* src, size_t len); + +template <typename T> +T toByteOrder(ByteOrder order, T const& t) { + T ret; + toByteOrder(order, &ret, &t, sizeof(t)); + return ret; +} + +template <typename T> +T fromByteOrder(ByteOrder order, T const& t) { + T ret; + fromByteOrder(order, &ret, &t, sizeof(t)); + return ret; +} + +template <typename T> +T toBigEndian(T const& t) { + return toByteOrder(ByteOrder::BigEndian, t); +} + +template <typename T> +T fromBigEndian(T const& t) { + return fromByteOrder(ByteOrder::BigEndian, t); +} + +template <typename T> +T toLittleEndian(T const& t) { + return toByteOrder(ByteOrder::LittleEndian, t); +} + +template <typename T> +T fromLittleEndian(T const& t) { + return fromByteOrder(ByteOrder::LittleEndian, t); +} + +inline ByteOrder platformByteOrder() { +#if STAR_LITTLE_ENDIAN + return ByteOrder::LittleEndian; +#else + return ByteOrder::BigEndian; +#endif +} + +inline void swapByteOrder(void* ptr, size_t len) { + uint8_t* data = static_cast<uint8_t*>(ptr); + uint8_t spare; + for (size_t i = 0; i < len / 2; ++i) { + spare = data[len - 1 - i]; + data[len - 1 - i] = data[i]; + data[i] = spare; + } +} + +inline void swapByteOrder(void* dest, const void* src, size_t len) { + const uint8_t* srcdata = static_cast<const uint8_t*>(src); + uint8_t* destdata = static_cast<uint8_t*>(dest); + for (size_t i = 0; i < len; ++i) + destdata[len - 1 - i] = srcdata[i]; +} + +inline void toByteOrder(ByteOrder order, void* ptr, size_t len) { + if (order != ByteOrder::NoConversion && platformByteOrder() != order) + swapByteOrder(ptr, len); +} + +inline void toByteOrder(ByteOrder order, void* dest, void const* src, size_t len) { + if (order != ByteOrder::NoConversion && platformByteOrder() != order) + swapByteOrder(dest, src, len); + else + memcpy(dest, src, len); +} + +inline void fromByteOrder(ByteOrder order, void* ptr, size_t len) { + if (order != ByteOrder::NoConversion && platformByteOrder() != order) + swapByteOrder(ptr, len); +} + +inline void fromByteOrder(ByteOrder order, void* dest, void const* src, size_t len) { + if (order != ByteOrder::NoConversion && platformByteOrder() != order) + swapByteOrder(dest, src, len); + else + memcpy(dest, src, len); +} + +} + +#endif diff --git a/source/core/StarCasting.hpp b/source/core/StarCasting.hpp new file mode 100644 index 0000000..64305af --- /dev/null +++ b/source/core/StarCasting.hpp @@ -0,0 +1,93 @@ +#ifndef STAR_CASTING_HPP +#define STAR_CASTING_HPP + +#include "StarException.hpp" +#include "StarFormat.hpp" + +namespace Star { + +STAR_EXCEPTION(PointerConvertException, StarException); + +template <typename Type1, typename Type2> +bool is(Type2* p) { + return (bool)dynamic_cast<Type1*>(p); +} + +template <typename Type1, typename Type2> +bool is(Type2 const* p) { + return (bool)dynamic_cast<Type1 const*>(p); +} + +template <typename Type1, typename Type2> +bool is(shared_ptr<Type2> const& p) { + return (bool)dynamic_cast<Type1*>(p.get()); +} + +template <typename Type1, typename Type2> +bool is(shared_ptr<Type2 const> const& p) { + return (bool)dynamic_cast<Type1 const*>(p.get()); +} + +template <typename Type1, typename Type2> +bool ris(Type2& r) { + return (bool)dynamic_cast<Type1*>(&r); +} + +template <typename Type1, typename Type2> +bool ris(Type2 const& r) { + return (bool)dynamic_cast<Type1 const*>(&r); +} + +template <typename Type1, typename Type2> +Type1* as(Type2* p) { + return dynamic_cast<Type1*>(p); +} + +template <typename Type1, typename Type2> +Type1 const* as(Type2 const* p) { + return dynamic_cast<Type1 const*>(p); +} + +template <typename Type1, typename Type2> +shared_ptr<Type1> as(shared_ptr<Type2> const& p) { + return dynamic_pointer_cast<Type1>(p); +} + +template <typename Type1, typename Type2> +shared_ptr<Type1 const> as(shared_ptr<Type2 const> const& p) { + return dynamic_pointer_cast<Type1 const>(p); +} + +template <typename Type, typename Ptr> +auto convert(Ptr const& p) -> decltype(as<Type>(p)) { + if (!p) + throw PointerConvertException::format("Could not convert from nullptr to %s", typeid(Type).name()); + else if (auto a = as<Type>(p)) + return a; + else + throw PointerConvertException::format("Could not convert from %s to %s", typeid(*p).name(), typeid(Type).name()); +} + +template <typename Type1, typename Type2> +Type1& rconvert(Type2& r) { + return *dynamic_cast<Type1*>(&r); +} + +template <typename Type1, typename Type2> +Type1 const& rconvert(Type2 const& r) { + return *dynamic_cast<Type1 const*>(&r); +} + +template <typename Type> +weak_ptr<Type> asWeak(shared_ptr<Type> const& p) { + return weak_ptr<Type>(p); +} + +template <typename Type> +weak_ptr<Type const> asWeak(shared_ptr<Type const> const& p) { + return weak_ptr<Type>(p); +} + +} + +#endif diff --git a/source/core/StarColor.cpp b/source/core/StarColor.cpp new file mode 100644 index 0000000..0503f19 --- /dev/null +++ b/source/core/StarColor.cpp @@ -0,0 +1,627 @@ +#include "StarColor.hpp" +#include "StarMap.hpp" +#include "StarEncode.hpp" +#include "StarFormat.hpp" +#include "StarInterpolation.hpp" + +namespace Star { + +Color const Color::Red = Color::rgba(255, 73, 66, 255); +Color const Color::Orange = Color::rgba(255, 180, 47, 255); +Color const Color::Yellow = Color::rgba(255, 239, 30, 255); +Color const Color::Green = Color::rgba(79, 230, 70, 255); +Color const Color::Blue = Color::rgba(38, 96, 255, 255); +Color const Color::Indigo = Color::rgba(75, 0, 130, 255); +Color const Color::Violet = Color::rgba(160, 119, 255, 255); +Color const Color::Black = Color::rgba(0, 0, 0, 255); +Color const Color::White = Color::rgba(255, 255, 255, 255); +Color const Color::Magenta = Color::rgba(221, 92, 249, 255); +Color const Color::DarkMagenta = Color::rgba(142, 33, 144, 255); +Color const Color::Cyan = Color::rgba(0, 220, 233, 255); +Color const Color::DarkCyan = Color::rgba(0, 137, 165, 255); +Color const Color::CornFlowerBlue = Color::rgba(100, 149, 237, 255); +Color const Color::Gray = Color::rgba(160, 160, 160, 255); +Color const Color::LightGray = Color::rgba(192, 192, 192, 255); +Color const Color::DarkGray = Color::rgba(128, 128, 128, 255); +Color const Color::DarkGreen = Color::rgba(0, 128, 0, 255); +Color const Color::Pink = Color::rgba(255, 162, 187, 255); +Color const Color::Clear = Color::rgba(0, 0, 0, 0); + +CaseInsensitiveStringMap<Color> const Color::NamedColors{ + {"red", Color::Red}, + {"orange", Color::Orange}, + {"yellow", Color::Yellow}, + {"green", Color::Green}, + {"blue", Color::Blue}, + {"indigo", Color::Indigo}, + {"violet", Color::Violet}, + {"black", Color::Black}, + {"white", Color::White}, + {"magenta", Color::Magenta}, + {"darkmagenta", Color::DarkMagenta}, + {"cyan", Color::Cyan}, + {"darkcyan", Color::DarkCyan}, + {"cornflowerblue", Color::CornFlowerBlue}, + {"gray", Color::Gray}, + {"lightgray", Color::LightGray}, + {"darkgray", Color::DarkGray}, + {"darkgreen", Color::DarkGreen}, + {"pink", Color::Pink}, + {"clear", Color::Clear} +}; + +Color Color::rgbf(const Vec3F& c) { + return rgbaf(c[0], c[1], c[2], 1.0f); +} + +Color Color::rgbaf(const Vec4F& c) { + return rgbaf(c[0], c[1], c[2], c[3]); +} + +Color Color::rgbf(float r, float g, float b) { + return rgbaf(r, g, b, 1.0f); +} + +Color Color::rgbaf(float r, float g, float b, float a) { + Color c; + c.m_data[0] = clamp(r, 0.0f, 1.0f); + c.m_data[1] = clamp(g, 0.0f, 1.0f); + c.m_data[2] = clamp(b, 0.0f, 1.0f); + c.m_data[3] = clamp(a, 0.0f, 1.0f); + return c; +} + +Color Color::rgb(uint8_t r, uint8_t g, uint8_t b) { + return rgba(r, g, b, 255); +} + +Color Color::rgba(uint8_t r, uint8_t g, uint8_t b, uint8_t a) { + Color c; + c.m_data[0] = r / 255.0f; + c.m_data[1] = g / 255.0f; + c.m_data[2] = b / 255.0f; + c.m_data[3] = a / 255.0f; + return c; +} + +Color Color::fromUint32(uint32_t v) { + Color c; + c.setAlpha(((uint8_t*)(&v))[3]); + c.setRed(((uint8_t*)(&v))[2]); + c.setGreen(((uint8_t*)(&v))[1]); + c.setBlue(((uint8_t*)(&v))[0]); + return c; +} + +Color Color::temperature(float temp) { + // Magic numbers ahoy! + Color c; + c.setAlpha(255); + + temp = clamp<float>(temp, 1000, 40000); + + temp /= 100; + + double r, g, b; + if (temp <= 66) { + r = 255; + g = clamp<double>(99.4708025861 * log(temp) - 161.1195681661, 0, 255); + if (temp <= 19) { + b = 0; + } else { + b = clamp<double>(138.5177312231 * log(temp - 10) - 305.0447927307, 0, 255); + } + } else { + r = clamp<double>(329.698727446 * pow(temp - 60, -0.1332047592), 0, 255); + g = clamp<double>(288.1221695283 * pow(temp - 60, -0.0755148492), 0, 255); + b = 255; + } + + c.setRedF((float)r / 255.0f); + c.setGreenF((float)g / 255.0f); + c.setBlueF((float)b / 255.0f); + + return c; +} + +Color Color::rgb(Vec3B const& c) { + return rgb(c[0], c[1], c[2]); +} + +Color Color::rgba(Vec4B const& c) { + return rgba(c[0], c[1], c[2], c[3]); +} + +Color Color::hsv(float h, float s, float v) { + return hsva(h, s, v, 1.0f); +} + +Color Color::hsva(float h, float s, float v, float a) { + h = clamp(h, 0.0f, 1.0f); + s = clamp(s, 0.0f, 1.0f); + v = clamp(v, 0.0f, 1.0f); + a = clamp(a, 0.0f, 1.0f); + + Color retColor; + if (s == 0.0f) { + retColor.setRedF(v); + retColor.setGreenF(v); + retColor.setBlueF(v); + retColor.setAlphaF(a); + } else { + float var_h, var_i, var_1, var_2, var_3, var_r, var_g, var_b; + + var_h = h * 6.0f; + if (var_h == 6.0f) + var_h = 0.0f; // H must be < 1 + + var_i = floor(var_h); + + var_1 = v * (1.0f - s); + var_2 = v * (1.0f - s * (var_h - var_i)); + var_3 = v * (1.0f - s * (1.0f - (var_h - var_i))); + + if (var_i == 0) { + var_r = v; + var_g = var_3; + var_b = var_1; + } else if (var_i == 1) { + var_r = var_2; + var_g = v; + var_b = var_1; + } else if (var_i == 2) { + var_r = var_1; + var_g = v; + var_b = var_3; + } else if (var_i == 3) { + var_r = var_1; + var_g = var_2; + var_b = v; + } else if (var_i == 4) { + var_r = var_3; + var_g = var_1; + var_b = v; + } else { + var_r = v; + var_g = var_1; + var_b = var_2; + } + + retColor.setRedF(var_r); + retColor.setGreenF(var_g); + retColor.setBlueF(var_b); + retColor.setAlphaF(a); + } + return retColor; +} + +Color Color::hsv(Vec3F const& c) { + return Color::hsv(c[0], c[1], c[2]); +} + +Color Color::hsva(Vec4F const& c) { + return Color::hsva(c[0], c[1], c[2], c[3]); +} + +Color Color::grayf(float g) { + return Color::rgbf(g, g, g); +} + +Color Color::gray(uint8_t g) { + return Color::rgb(g, g, g); +} + +Color::Color() {} + +Color::Color(const String& name) { + if (name.beginsWith("#")) + *this = fromHex(name.substr(1)); + else { + auto i = NamedColors.find(name.toLower()); + if (i != NamedColors.end()) + *this = i->second; + else + throw ColorException(strf("Named color %s not found", name)); + } +} + +float Color::redF() const { + return m_data[0]; +} + +float Color::greenF() const { + return m_data[1]; +} + +float Color::blueF() const { + return m_data[2]; +} + +float Color::alphaF() const { + return m_data[3]; +} + +bool Color::isClear() const { + return m_data[3] == 0; +} + +uint8_t Color::red() const { + return uint8_t(round(m_data[0] * 255)); +} + +uint8_t Color::green() const { + return uint8_t(round(m_data[1] * 255)); +} + +uint8_t Color::blue() const { + return uint8_t(round(m_data[2] * 255)); +} + +uint8_t Color::alpha() const { + return uint8_t(m_data[3] * 255); +} + +void Color::setRedF(float r) { + m_data[0] = clamp(r, 0.0f, 1.0f); +} + +void Color::setGreenF(float g) { + m_data[1] = clamp(g, 0.0f, 1.0f); +} + +void Color::setBlueF(float b) { + m_data[2] = clamp(b, 0.0f, 1.0f); +} + +void Color::setAlphaF(float a) { + m_data[3] = clamp(a, 0.0f, 1.0f); +} + +void Color::setRed(uint8_t r) { + m_data[0] = r / 255.0f; +} + +void Color::setGreen(uint8_t g) { + m_data[1] = g / 255.0f; +} + +void Color::setBlue(uint8_t b) { + m_data[2] = b / 255.0f; +} + +void Color::setAlpha(uint8_t a) { + m_data[3] = a / 255.0f; +} + +uint32_t Color::toUint32() const { + uint32_t val; + ((uint8_t*)(&val))[3] = alpha(); + ((uint8_t*)(&val))[2] = red(); + ((uint8_t*)(&val))[1] = green(); + ((uint8_t*)(&val))[0] = blue(); + return val; +} + +Color Color::fromHex(String const& s) { + uint8_t cbytes[4]; + + if (s.utf8Size() == 3) { + nibbleDecode(s.utf8Ptr(), 3, (char*)cbytes, 4); + cbytes[0] = (cbytes[0] << 4) | cbytes[0]; + cbytes[1] = (cbytes[1] << 4) | cbytes[1]; + cbytes[2] = (cbytes[2] << 4) | cbytes[2]; + cbytes[3] = 255; + } else if (s.utf8Size() == 4) { + nibbleDecode(s.utf8Ptr(), 4, (char*)cbytes, 4); + cbytes[0] = (cbytes[0] << 4) | cbytes[0]; + cbytes[1] = (cbytes[1] << 4) | cbytes[1]; + cbytes[2] = (cbytes[2] << 4) | cbytes[2]; + cbytes[3] = (cbytes[3] << 4) | cbytes[3]; + } else if (s.utf8Size() == 6) { + hexDecode(s.utf8Ptr(), 6, (char*)cbytes, 4); + cbytes[3] = 255; + } else if (s.utf8Size() == 8) { + hexDecode(s.utf8Ptr(), 8, (char*)cbytes, 4); + } else { + throw ColorException(strf("Improper size for hex string '%s' in Color::fromHex", s)); + } + + return Color::rgba(cbytes[0], cbytes[1], cbytes[2], cbytes[3]); +} + +Vec4B Color::toRgba() const { + return Vec4B(red(), green(), blue(), alpha()); +} + +Vec3B Color::toRgb() const { + return Vec3B(red(), green(), blue()); +} + +Vec4F Color::toRgbaF() const { + return Vec4F(redF(), greenF(), blueF(), alphaF()); +} + +Vec3F Color::toRgbF() const { + return Vec3F(redF(), greenF(), blueF()); +} + +Vec4F Color::toHsva() const { + float h, s, v; + + float var_r = redF(); + float var_g = greenF(); + float var_b = blueF(); + + // Min. value of RGB + float var_min = min(min(var_r, var_g), var_b); + + // Max. value of RGB + float var_max = max(max(var_r, var_g), var_b); + + // Delta RGB value + float del_max = var_max - var_min; + + v = var_max; + + if (del_max == 0.0f) { // This is a gray, no chroma... + h = 0.0f; + s = 0.0f; + } else { // Chromatic data + s = del_max / var_max; + + float del_r = (((var_max - var_r) / 6.0f) + (del_max / 2.0f)) / del_max; + float del_g = (((var_max - var_g) / 6.0f) + (del_max / 2.0f)) / del_max; + float del_b = (((var_max - var_b) / 6.0f) + (del_max / 2.0f)) / del_max; + + if (var_r == var_max) + h = del_b - del_g; + else if (var_g == var_max) + h = (1.0f / 3.0f) + del_r - del_b; + else + /*if (var_b == var_max)*/ h = (2.0f / 3.0f) + del_g - del_r; + + if (h < 0.0f) + h += 1.0f; + if (h >= 1.0f) + h -= 1.0f; + } + + return Vec4F(h, s, v, alphaF()); +} + +String Color::toHex() const { + auto rgba = toRgba(); + return hexEncode((char*)rgba.ptr(), rgba[3] == 255 ? 3 : 4); +} + +float Color::hue() const { + return toHsva()[0]; +} + +float Color::saturation() const { + // Min. value of RGB + float var_min = min(min(m_data[0], m_data[1]), m_data[2]); + + // Max. value of RGB + float var_max = max(max(m_data[0], m_data[1]), m_data[2]); + + // Delta RGB value + float del_max = var_max - var_min; + + if (del_max == 0.0f) { // This is a gray, no chroma... + return 0.0f; + } else + return del_max / var_max; +} + +float Color::value() const { + return max(max(m_data[0], m_data[1]), m_data[2]); +} + +void Color::setHue(float h) { + auto hsva = toHsva(); + *this = Color::hsva(clamp(h, 0.0f, 1.0f), hsva[1], hsva[2], alphaF()); +} + +void Color::setSaturation(float s) { + auto hsva = toHsva(); + *this = Color::hsva(hsva[0], clamp(s, 0.0f, 1.0f), hsva[2], alphaF()); +} + +void Color::setValue(float v) { + auto hsva = toHsva(); + *this = Color::hsva(hsva[0], hsva[1], clamp(v, 0.0f, 1.0f), alphaF()); +} + +void Color::hueShift(float h) { + setHue(pfmod(hue() + h, 1.0f)); +} + +void Color::fade(float value) { + m_data *= (1.0f - value); + m_data.clamp(0.0f, 1.0f); +} + +bool Color::operator==(const Color& c) const { + return m_data == c.m_data; +} + +bool Color::operator!=(const Color& c) const { + return m_data != c.m_data; +} + +std::ostream& operator<<(std::ostream& os, const Color& c) { + os << c.toRgbaF(); + return os; +} + +float Color::toLinear(float in) { + const float a = 0.055f; + if (in <= 0.04045f) + return in / 12.92f; + return powf((in + a) / (1.0f + a), 2.4f); +} + +float Color::fromLinear(float in) { + const float a = 0.055f; + if (in <= 0.0031308f) + return 12.92f * in; + return (1.0f + a) * powf(in, 1.0f / 2.4f) - a; +} + +void Color::convertToLinear() { + setRedF(toLinear(redF())); + setGreenF(toLinear(greenF())); + setBlueF(toLinear(blueF())); +} + +void Color::convertToSRGB() { + setRedF(fromLinear(redF())); + setGreenF(fromLinear(greenF())); + setBlueF(fromLinear(blueF())); +} + +Color Color::toLinear() { + Color c = *this; + c.convertToLinear(); + return c; +} + +Color Color::toSRGB() { + Color c = *this; + c.convertToSRGB(); + return c; +} + +Color Color::contrasting() { + Color c = *this; + c.setHue(c.hue() + 120); + return c; +} + +Color Color::complementary() { + Color c = *this; + c.setHue(c.hue() + 180); + return c; +} + +Color Color::mix(Color const& c, float amount) const { + return Color::rgbaf(lerp(clamp(amount, 0.0f, 1.0f), toRgbaF(), c.toRgbaF())); +} + +Color Color::multiply(float amount) const { + return Color::rgbaf(m_data * amount); +} + +Color Color::operator+(Color const& c) const { + return Color::rgbaf(m_data + c.toRgbaF()); +} + +Color Color::operator*(Color const& c) const { + return Color::rgbaf(m_data.piecewiseMultiply(c.toRgbaF())); +} + +Color& Color::operator+=(Color const& c) { + return * this = *this + c; +} + +Color& Color::operator*=(Color const& c) { + return * this = *this * c; +} + +Vec4B Color::hueShiftVec4B(Vec4B color, float hue) { + float h, s, v; + + float var_r = color[0] / 255.0f; + float var_g = color[1] / 255.0f; + float var_b = color[2] / 255.0f; + + // Min. value of RGB + float var_min = min(min(var_r, var_g), var_b); + + // Max. value of RGB + float var_max = max(max(var_r, var_g), var_b); + + // Delta RGB value + float del_max = var_max - var_min; + + v = var_max; + + if (del_max == 0.0f) { // This is a gray, no chroma... + h = 0.0f; + s = 0.0f; + } else { // Chromatic data + s = del_max / var_max; + + float vd = 1.0f / 6.0f; + float dmh = del_max * 0.5f; + float dmi = 1.0f / del_max; + float del_r = (((var_max - var_r) * vd) + dmh) * dmi; + float del_g = (((var_max - var_g) * vd) + dmh) * dmi; + float del_b = (((var_max - var_b) * vd) + dmh) * dmi; + + if (var_r == var_max) + h = del_b - del_g; + else if (var_g == var_max) + h = (1.0f / 3.0f) + del_r - del_b; + else + h = (2.0f / 3.0f) + del_g - del_r; + + if (h < 0.0f) + h += 1.0f; + if (h >= 1.0f) + h -= 1.0f; + } + + h += hue; + + if (h >= 1.0f) + h -= 1.0f; + + if (s == 0.0f) { + auto c = uint8_t(round(v * 255)); + return Vec4B(c, c, c, color[3]); + } else { + float var_h, var_i, var_1, var_2, var_3, var_r, var_g, var_b; + + var_h = h * 6.0f; + if (var_h == 6.0f) + var_h = 0.0f; // H must be < 1 + + var_i = floor(var_h); + + var_1 = v * (1.0f - s); + var_2 = v * (1.0f - s * (var_h - var_i)); + var_3 = v * (1.0f - s * (1.0f - (var_h - var_i))); + + if (var_i == 0) { + var_r = v; + var_g = var_3; + var_b = var_1; + } else if (var_i == 1) { + var_r = var_2; + var_g = v; + var_b = var_1; + } else if (var_i == 2) { + var_r = var_1; + var_g = v; + var_b = var_3; + } else if (var_i == 3) { + var_r = var_1; + var_g = var_2; + var_b = v; + } else if (var_i == 4) { + var_r = var_3; + var_g = var_1; + var_b = v; + } else { + var_r = v; + var_g = var_1; + var_b = var_2; + } + + return Vec4B(uint8_t(round(var_r * 255)), uint8_t(round(var_g * 255)), uint8_t(round(var_b * 255)), color[3]); + } +} + +} diff --git a/source/core/StarColor.hpp b/source/core/StarColor.hpp new file mode 100644 index 0000000..33f3374 --- /dev/null +++ b/source/core/StarColor.hpp @@ -0,0 +1,171 @@ +#ifndef STAR_COLOR_HPP +#define STAR_COLOR_HPP + +#include "StarString.hpp" +#include "StarVector.hpp" + +namespace Star { + +STAR_EXCEPTION(ColorException, StarException); + +class Color { +public: + static Color const Red; + static Color const Orange; + static Color const Yellow; + static Color const Green; + static Color const Blue; + static Color const Indigo; + static Color const Violet; + static Color const Black; + static Color const White; + static Color const Magenta; + static Color const DarkMagenta; + static Color const Cyan; + static Color const DarkCyan; + static Color const CornFlowerBlue; + static Color const Gray; + static Color const LightGray; + static Color const DarkGray; + static Color const DarkGreen; + static Color const Pink; + static Color const Clear; + + static CaseInsensitiveStringMap<Color> const NamedColors; + + // Some useful conversion methods for dealing with Vec3 / Vec4 as colors + static Vec3F v3bToFloat(Vec3B const& b); + static Vec3B v3fToByte(Vec3F const& f, bool doClamp = true); + static Vec4F v4bToFloat(Vec4B const& b); + static Vec4B v4fToByte(Vec4F const& f, bool doClamp = true); + + static Color rgbf(float r, float g, float b); + static Color rgbaf(float r, float g, float b, float a); + static Color rgbf(Vec3F const& c); + static Color rgbaf(Vec4F const& c); + + static Color rgb(uint8_t r, uint8_t g, uint8_t b); + static Color rgba(uint8_t r, uint8_t g, uint8_t b, uint8_t a); + static Color rgb(Vec3B const& c); + static Color rgba(Vec4B const& c); + + static Color hsv(float h, float s, float b); + static Color hsva(float h, float s, float b, float a); + static Color hsv(Vec3F const& c); + static Color hsva(Vec4F const& c); + + static Color grayf(float g); + static Color gray(uint8_t g); + + // Only supports 8 bit color + static Color fromHex(String const& s); + + // #AARRGGBB + static Color fromUint32(uint32_t v); + + // Color from temperature in Kelvin + static Color temperature(float temp); + + static Vec4B hueShiftVec4B(Vec4B color, float hue); + + // Black + Color(); + + explicit Color(String const& name); + + uint8_t red() const; + uint8_t green() const; + uint8_t blue() const; + uint8_t alpha() const; + + void setRed(uint8_t r); + void setGreen(uint8_t g); + void setBlue(uint8_t b); + void setAlpha(uint8_t a); + + float redF() const; + float greenF() const; + float blueF() const; + float alphaF() const; + + void setRedF(float r); + void setGreenF(float b); + void setBlueF(float g); + void setAlphaF(float a); + + bool isClear() const; + + // Returns a 4 byte value equal to #AARRGGBB + uint32_t toUint32() const; + + Vec4B toRgba() const; + Vec3B toRgb() const; + Vec4F toRgbaF() const; + Vec3F toRgbF() const; + + Vec4F toHsva() const; + + String toHex() const; + + float hue() const; + float saturation() const; + float value() const; + + void setHue(float hue); + void setSaturation(float saturation); + void setValue(float value); + + // Shift the current hue by the given value, with hue wrapping. + void hueShift(float hue); + + // Reduce the color toward black by the given amount, from 0.0 to 1.0. + void fade(float value); + + void convertToLinear(); + void convertToSRGB(); + + Color toLinear(); + Color toSRGB(); + + Color contrasting(); + Color complementary(); + + // Mix two colors, giving the second color the given amount + Color mix(Color const& c, float amount = 0.5f) const; + Color multiply(float amount) const; + + bool operator==(Color const& c) const; + bool operator!=(Color const& c) const; + Color operator+(Color const& c) const; + Color operator*(Color const& c) const; + Color& operator+=(Color const& c); + Color& operator*=(Color const& c); + +private: + static float toLinear(float in); + static float fromLinear(float in); + + Vec4F m_data; +}; + +std::ostream& operator<<(std::ostream& os, Color const& c); + +inline Vec3F Color::v3bToFloat(Vec3B const& b) { + return Vec3F(byteToFloat(b[0]), byteToFloat(b[1]), byteToFloat(b[2])); +} + +inline Vec3B Color::v3fToByte(Vec3F const& f, bool doClamp) { + return Vec3B(floatToByte(f[0], doClamp), floatToByte(f[1], doClamp), floatToByte(f[2], doClamp)); +} + +inline Vec4F Color::v4bToFloat(Vec4B const& b) { + return Vec4F(byteToFloat(b[0]), byteToFloat(b[1]), byteToFloat(b[2]), byteToFloat(b[3])); +} + +inline Vec4B Color::v4fToByte(Vec4F const& f, bool doClamp) { + return Vec4B(floatToByte(f[0], doClamp), floatToByte(f[1], doClamp), floatToByte(f[2], doClamp), floatToByte(f[3], doClamp)); +} + +} + +#endif diff --git a/source/core/StarCompression.cpp b/source/core/StarCompression.cpp new file mode 100644 index 0000000..e8d282c --- /dev/null +++ b/source/core/StarCompression.cpp @@ -0,0 +1,223 @@ +#include "StarCompression.hpp" +#include "StarFormat.hpp" +#include "StarLexicalCast.hpp" + +#include <zlib.h> +#include <errno.h> +#include <string.h> + +namespace Star { + +void compressData(ByteArray const& in, ByteArray& out, CompressionLevel compression) { + out.clear(); + + if (in.empty()) + return; + + const size_t BUFSIZE = 32 * 1024; + unsigned char temp_buffer[BUFSIZE]; + + z_stream strm; + strm.zalloc = Z_NULL; + strm.zfree = Z_NULL; + strm.opaque = Z_NULL; + int deflate_res = deflateInit(&strm, compression); + if (deflate_res != Z_OK) + throw IOException(strf("Failed to initialise deflate (%d)", deflate_res)); + + strm.next_in = (unsigned char*)in.ptr(); + strm.avail_in = in.size(); + strm.next_out = temp_buffer; + strm.avail_out = BUFSIZE; + while (deflate_res == Z_OK) { + deflate_res = deflate(&strm, Z_FINISH); + if (strm.avail_out == 0) { + out.append((char const*)temp_buffer, BUFSIZE); + strm.next_out = temp_buffer; + strm.avail_out = BUFSIZE; + } + } + deflateEnd(&strm); + + if (deflate_res != Z_STREAM_END) + throw IOException(strf("Internal error in uncompressData, deflate_res is %s", deflate_res)); + + out.append((char const*)temp_buffer, BUFSIZE - strm.avail_out); +} + +ByteArray compressData(ByteArray const& in, CompressionLevel compression) { + ByteArray out = ByteArray::withReserve(in.size()); + compressData(in, out, compression); + return out; +} + +void uncompressData(ByteArray const& in, ByteArray& out) { + out.clear(); + + if (in.empty()) + return; + + const size_t BUFSIZE = 32 * 1024; + unsigned char temp_buffer[BUFSIZE]; + + z_stream strm; + strm.zalloc = Z_NULL; + strm.zfree = Z_NULL; + strm.opaque = Z_NULL; + int inflate_res = inflateInit(&strm); + if (inflate_res != Z_OK) + throw IOException(strf("Failed to initialise inflate (%d)", inflate_res)); + + strm.next_in = (unsigned char*)in.ptr(); + strm.avail_in = in.size(); + strm.next_out = temp_buffer; + strm.avail_out = BUFSIZE; + + while (inflate_res == Z_OK || inflate_res == Z_BUF_ERROR) { + inflate_res = inflate(&strm, Z_FINISH); + if (strm.avail_out == 0) { + out.append((char const*)temp_buffer, BUFSIZE); + strm.next_out = temp_buffer; + strm.avail_out = BUFSIZE; + } else if (inflate_res == Z_BUF_ERROR) { + break; + } + } + inflateEnd(&strm); + + if (inflate_res != Z_STREAM_END) + throw IOException(strf("Internal error in uncompressData, inflate_res is %s", inflate_res)); + + out.append((char const*)temp_buffer, BUFSIZE - strm.avail_out); +} + +ByteArray uncompressData(ByteArray const& in) { + ByteArray out = ByteArray::withReserve(in.size()); + uncompressData(in, out); + return out; +} + +CompressedFilePtr CompressedFile::open(String const& filename, IOMode mode, CompressionLevel comp) { + CompressedFilePtr f = make_shared<CompressedFile>(filename); + f->open(mode, comp); + return f; +} + +CompressedFile::CompressedFile() + : IODevice(IOMode::Closed), m_file(0), m_compression(MediumCompression) {} + +CompressedFile::CompressedFile(String filename) + : IODevice(IOMode::Closed), m_file(0), m_compression(MediumCompression) { + setFilename(move(filename)); +} + +CompressedFile::~CompressedFile() { + close(); +} + +StreamOffset CompressedFile::pos() { + return gztell((gzFile)m_file); +} + +void CompressedFile::seek(StreamOffset offset, IOSeek seekMode) { + StreamOffset begPos = pos(); + + int retCode; + if (seekMode == IOSeek::Relative) { + retCode = gzseek((gzFile)m_file, (z_off_t)offset, SEEK_CUR); + } else if (seekMode == IOSeek::Absolute) { + retCode = gzseek((gzFile)m_file, (z_off_t)offset, SEEK_SET); + } else { + throw IOException("Cannot seek with SeekEnd in compressed file"); + } + + StreamOffset endPos = pos(); + + if (retCode < 0) { + throw IOException::format("Seek error: %s", gzerror((gzFile)m_file, 0)); + } else if ((seekMode == IOSeek::Relative && begPos + offset != endPos) + || (seekMode == IOSeek::Absolute && offset != endPos)) { + throw EofException("Error, unexpected end of file found"); + } +} + +bool CompressedFile::atEnd() { + return gzeof((gzFile)m_file); +} + +size_t CompressedFile::read(char* data, size_t len) { + if (len == 0) + return 0; + + int ret = gzread((gzFile)m_file, data, len); + if (ret == 0) + throw EofException("Error, unexpected end of file found"); + else if (ret == -1) + throw IOException::format("Read error: %s", gzerror((gzFile)m_file, 0)); + else + return (size_t)ret; +} + +size_t CompressedFile::write(const char* data, size_t len) { + if (len == 0) + return 0; + + int ret = gzwrite((gzFile)m_file, data, len); + if (ret == 0) + throw IOException::format("Write error: %s", gzerror((gzFile)m_file, 0)); + else + return (size_t)ret; +} + +void CompressedFile::setFilename(String filename) { + if (isOpen()) + throw IOException("Cannot call setFilename while CompressedFile is open"); + m_filename = move(filename); +} + +void CompressedFile::setCompression(CompressionLevel compression) { + if (isOpen()) + throw IOException("Cannot call setCompression while CompressedFile is open"); + m_compression = compression; +} + +void CompressedFile::open(IOMode mode, CompressionLevel compression) { + close(); + setCompression(compression); + open(mode); +} + +void CompressedFile::sync() { + gzflush((gzFile)m_file, Z_FULL_FLUSH); +} + +void CompressedFile::open(IOMode mode) { + setMode(mode); + String modeString; + + if (mode & IOMode::Append) { + throw IOException("CompressedFile not compatible with Append mode"); + } else if ((mode & IOMode::Read) && (mode & IOMode::Write)) { + throw IOException("CompressedFile not compatible with ReadWrite mode"); + } else if (mode & IOMode::Write) { + modeString = "wb"; + } else if (mode & IOMode::Read) { + modeString = "rb"; + } + + modeString += toString(m_compression); + + m_file = gzopen(m_filename.utf8Ptr(), modeString.utf8Ptr()); + + if (!m_file) + throw IOException::format("Cannot open filename '%s'", m_filename); +} + +void CompressedFile::close() { + if (m_file) + gzclose((gzFile)m_file); + m_file = 0; + setMode(IOMode::Closed); +} + +} diff --git a/source/core/StarCompression.hpp b/source/core/StarCompression.hpp new file mode 100644 index 0000000..ba36dc7 --- /dev/null +++ b/source/core/StarCompression.hpp @@ -0,0 +1,59 @@ +#ifndef STAR_COMPRESSION_HPP +#define STAR_COMPRESSION_HPP + +#include "StarIODevice.hpp" +#include "StarString.hpp" + +namespace Star { + +STAR_CLASS(CompressedFile); + +// Zlib compression level, ranges from 0 to 9 +typedef int CompressionLevel; + +CompressionLevel const LowCompression = 2; +CompressionLevel const MediumCompression = 5; +CompressionLevel const HighCompression = 9; + +void compressData(ByteArray const& in, ByteArray& out, CompressionLevel compression = MediumCompression); +ByteArray compressData(ByteArray const& in, CompressionLevel compression = MediumCompression); + +void uncompressData(ByteArray const& in, ByteArray& out); +ByteArray uncompressData(ByteArray const& in); + +// Random access to a (potentially) compressed file. +class CompressedFile : public IODevice { +public: + static CompressedFilePtr open(String const& filename, IOMode mode, CompressionLevel comp = MediumCompression); + + CompressedFile(); + CompressedFile(String filename); + virtual ~CompressedFile(); + + void setFilename(String filename); + void setCompression(CompressionLevel compression); + + StreamOffset pos() override; + // Only seek forward is supported on writes. Seek is emulated *slowly* on + // reads. + void seek(StreamOffset pos, IOSeek seek = IOSeek::Absolute) override; + bool atEnd() override; + size_t read(char* data, size_t len) override; + size_t write(char const* data, size_t len) override; + + void open(IOMode mode) override; + // Compression is ignored on read. Always truncates on write + void open(IOMode mode, CompressionLevel compression); + + void sync() override; + void close() override; + +private: + String m_filename; + void* m_file; + CompressionLevel m_compression; +}; + +} + +#endif diff --git a/source/core/StarConfig.hpp b/source/core/StarConfig.hpp new file mode 100644 index 0000000..22c2607 --- /dev/null +++ b/source/core/StarConfig.hpp @@ -0,0 +1,113 @@ +#ifndef STAR_CONFIG_HPP +#define STAR_CONFIG_HPP + +#include <cstdint> +#include <cstdlib> +#include <cstddef> +#include <cstring> +#include <cmath> +#include <tuple> +#include <memory> +#include <functional> +#include <algorithm> +#include <iostream> +#include <initializer_list> +#include <exception> +#include <stdexcept> +#include <atomic> +#include <string> +#include <iterator> + +namespace Star { + +// Some really common std namespace includes + +using std::size_t; + +using std::swap; +using std::move; + +using std::unique_ptr; +using std::shared_ptr; +using std::weak_ptr; +using std::make_shared; +using std::make_unique; +using std::static_pointer_cast; +using std::dynamic_pointer_cast; +using std::const_pointer_cast; +using std::enable_shared_from_this; + +using std::pair; +using std::make_pair; + +using std::tuple; +using std::make_tuple; +using std::tuple_element; +using std::get; +using std::tie; +using std::ignore; + +using std::initializer_list; + +using std::min; +using std::max; + +using std::bind; +using std::function; +using std::forward; +using std::mem_fn; +using std::ref; +using std::cref; +using namespace std::placeholders; + +using std::prev; +// using std::next; + +using std::atomic; +using std::atomic_flag; +using std::atomic_load; +using std::atomic_store; + +#ifndef NDEBUG +#define STAR_DEBUG 1 +constexpr bool DebugEnabled = true; +#else +constexpr bool DebugEnabled = false; +#endif + +// A version of string::npos that's used in general to mean "not a position" +// and is the largest value for size_t. +size_t const NPos = (size_t)(-1); + +typedef int64_t StreamOffset; + +// Convenient way to purposefully mark a variable as unused to avoid warning +#define _unused(x) ((void)x) + +// Forward declare a class or struct, and define a lot of typedefs for +// different pointer types all at once. + +#define STAR_CLASS(ClassName) \ + class ClassName; \ + typedef std::shared_ptr<ClassName> ClassName##Ptr; \ + typedef std::shared_ptr<const ClassName> ClassName##ConstPtr; \ + typedef std::weak_ptr<ClassName> ClassName##WeakPtr; \ + typedef std::weak_ptr<const ClassName> ClassName##ConstWeakPtr; \ + typedef std::unique_ptr<ClassName> ClassName##UPtr; \ + typedef std::unique_ptr<const ClassName> ClassName##ConstUPtr + +#define STAR_STRUCT(StructName) \ + struct StructName; \ + typedef std::shared_ptr<StructName> StructName##Ptr; \ + typedef std::shared_ptr<const StructName> StructName##ConstPtr; \ + typedef std::weak_ptr<StructName> StructName##WeakPtr; \ + typedef std::weak_ptr<const StructName> StructName##ConstWeakPtr; \ + typedef std::unique_ptr<StructName> StructName##UPtr; \ + typedef std::unique_ptr<const StructName> StructName##ConstUPtr + +#define STAR_QUOTE(name) #name +#define STAR_STR(macro) STAR_QUOTE(macro) + +} + +#endif diff --git a/source/core/StarDataStream.cpp b/source/core/StarDataStream.cpp new file mode 100644 index 0000000..c8d50a4 --- /dev/null +++ b/source/core/StarDataStream.cpp @@ -0,0 +1,311 @@ +#include "StarDataStream.hpp" +#include "StarBytes.hpp" +#include "StarVlqEncoding.hpp" + +#include <string.h> + +namespace Star { + +DataStream::DataStream() + : m_byteOrder(ByteOrder::BigEndian), + m_nullTerminatedStrings(false), + m_streamCompatibilityVersion(CurrentStreamVersion) {} + +ByteOrder DataStream::byteOrder() const { + return m_byteOrder; +} + +void DataStream::setByteOrder(ByteOrder byteOrder) { + m_byteOrder = byteOrder; +} + +bool DataStream::nullTerminatedStrings() const { + return m_nullTerminatedStrings; +} + +void DataStream::setNullTerminatedStrings(bool nullTerminatedStrings) { + m_nullTerminatedStrings = nullTerminatedStrings; +} + +unsigned DataStream::streamCompatibilityVersion() const { + return m_streamCompatibilityVersion; +} + +void DataStream::setStreamCompatibilityVersion(unsigned streamCompatibilityVersion) { + m_streamCompatibilityVersion = streamCompatibilityVersion; +} + +ByteArray DataStream::readBytes(size_t len) { + ByteArray ba; + ba.resize(len); + readData(ba.ptr(), len); + return ba; +} + +void DataStream::writeBytes(ByteArray const& ba) { + writeData(ba.ptr(), ba.size()); +} + +DataStream& DataStream::operator<<(bool d) { + operator<<((uint8_t)d); + return *this; +} + +DataStream& DataStream::operator<<(char c) { + writeData(&c, 1); + return *this; +} + +DataStream& DataStream::operator<<(int8_t d) { + writeData((char*)&d, sizeof(d)); + return *this; +} + +DataStream& DataStream::operator<<(uint8_t d) { + writeData((char*)&d, sizeof(d)); + return *this; +} + +DataStream& DataStream::operator<<(int16_t d) { + d = toByteOrder(m_byteOrder, d); + writeData((char*)&d, sizeof(d)); + return *this; +} + +DataStream& DataStream::operator<<(uint16_t d) { + d = toByteOrder(m_byteOrder, d); + writeData((char*)&d, sizeof(d)); + return *this; +} + +DataStream& DataStream::operator<<(int32_t d) { + d = toByteOrder(m_byteOrder, d); + writeData((char*)&d, sizeof(d)); + return *this; +} + +DataStream& DataStream::operator<<(uint32_t d) { + d = toByteOrder(m_byteOrder, d); + writeData((char*)&d, sizeof(d)); + return *this; +} + +DataStream& DataStream::operator<<(int64_t d) { + d = toByteOrder(m_byteOrder, d); + writeData((char*)&d, sizeof(d)); + return *this; +} + +DataStream& DataStream::operator<<(uint64_t d) { + d = toByteOrder(m_byteOrder, d); + writeData((char*)&d, sizeof(d)); + return *this; +} + +DataStream& DataStream::operator<<(float d) { + d = toByteOrder(m_byteOrder, d); + writeData((char*)&d, sizeof(d)); + return *this; +} + +DataStream& DataStream::operator<<(double d) { + d = toByteOrder(m_byteOrder, d); + writeData((char*)&d, sizeof(d)); + return *this; +} + +DataStream& DataStream::operator>>(bool& d) { + uint8_t bu; + readData((char*)&bu, sizeof(bu)); + d = (bool)bu; + return *this; +} + +DataStream& DataStream::operator>>(char& c) { + readData(&c, 1); + return *this; +} + +DataStream& DataStream::operator>>(int8_t& d) { + readData((char*)&d, sizeof(d)); + return *this; +} + +DataStream& DataStream::operator>>(uint8_t& d) { + readData((char*)&d, sizeof(d)); + return *this; +} + +DataStream& DataStream::operator>>(int16_t& d) { + readData((char*)&d, sizeof(d)); + d = fromByteOrder(m_byteOrder, d); + return *this; +} + +DataStream& DataStream::operator>>(uint16_t& d) { + readData((char*)&d, sizeof(d)); + d = fromByteOrder(m_byteOrder, d); + return *this; +} + +DataStream& DataStream::operator>>(int32_t& d) { + readData((char*)&d, sizeof(d)); + d = fromByteOrder(m_byteOrder, d); + return *this; +} + +DataStream& DataStream::operator>>(uint32_t& d) { + readData((char*)&d, sizeof(d)); + d = fromByteOrder(m_byteOrder, d); + return *this; +} + +DataStream& DataStream::operator>>(int64_t& d) { + readData((char*)&d, sizeof(d)); + d = fromByteOrder(m_byteOrder, d); + return *this; +} + +DataStream& DataStream::operator>>(uint64_t& d) { + readData((char*)&d, sizeof(d)); + d = fromByteOrder(m_byteOrder, d); + return *this; +} + +DataStream& DataStream::operator>>(float& d) { + readData((char*)&d, sizeof(d)); + d = fromByteOrder(m_byteOrder, d); + return *this; +} + +DataStream& DataStream::operator>>(double& d) { + readData((char*)&d, sizeof(d)); + d = fromByteOrder(m_byteOrder, d); + return *this; +} + +size_t DataStream::writeVlqU(uint64_t i) { + return Star::writeVlqU(i, makeFunctionOutputIterator([this](uint8_t b) { *this << b; })); +} + +size_t DataStream::writeVlqI(int64_t i) { + return Star::writeVlqI(i, makeFunctionOutputIterator([this](uint8_t b) { *this << b; })); +} + +size_t DataStream::writeVlqS(size_t i) { + uint64_t i64; + if (i == NPos) + i64 = 0; + else + i64 = i + 1; + return writeVlqU(i64); +} + +size_t DataStream::readVlqU(uint64_t& i) { + size_t bytesRead = Star::readVlqU(i, makeFunctionInputIterator([this]() { return this->read<uint8_t>(); })); + + if (bytesRead == NPos) + throw DataStreamException("Error reading VLQ encoded intenger!"); + + return bytesRead; +} + +size_t DataStream::readVlqI(int64_t& i) { + size_t bytesRead = Star::readVlqI(i, makeFunctionInputIterator([this]() { return this->read<uint8_t>(); })); + + if (bytesRead == NPos) + throw DataStreamException("Error reading VLQ encoded intenger!"); + + return bytesRead; +} + +size_t DataStream::readVlqS(size_t& i) { + uint64_t i64; + size_t res = readVlqU(i64); + if (i64 == 0) + i = NPos; + else + i = (size_t)(i64 - 1); + return res; +} + +uint64_t DataStream::readVlqU() { + uint64_t i; + readVlqU(i); + return i; +} + +int64_t DataStream::readVlqI() { + int64_t i; + readVlqI(i); + return i; +} + +size_t DataStream::readVlqS() { + size_t i; + readVlqS(i); + return i; +} + +DataStream& DataStream::operator<<(char const* s) { + writeStringData(s, strlen(s)); + return *this; +} + +DataStream& DataStream::operator<<(std::string const& d) { + writeStringData(d.c_str(), d.size()); + return *this; +} + +DataStream& DataStream::operator<<(const ByteArray& d) { + writeVlqU(d.size()); + writeData(d.ptr(), d.size()); + return *this; +} + +DataStream& DataStream::operator<<(const String& s) { + writeStringData(s.utf8Ptr(), s.utf8Size()); + return *this; +} + +DataStream& DataStream::operator>>(std::string& d) { + if (m_nullTerminatedStrings) { + d.clear(); + char c; + while (true) { + readData((char*)&c, sizeof(c)); + if (c == '\0') + break; + d.push_back(c); + } + } else { + d.resize((size_t)readVlqU()); + readData(&d[0], d.size()); + } + return *this; +} + +DataStream& DataStream::operator>>(ByteArray& d) { + d.resize((size_t)readVlqU()); + readData(d.ptr(), d.size()); + return *this; +} + +DataStream& DataStream::operator>>(String& s) { + std::string string; + operator>>(string); + s = move(string); + return *this; +} + +void DataStream::writeStringData(char const* data, size_t len) { + if (m_nullTerminatedStrings) { + writeData(data, len); + operator<<((uint8_t)0x00); + } else { + writeVlqU(len); + writeData(data, len); + } +} + +} diff --git a/source/core/StarDataStream.hpp b/source/core/StarDataStream.hpp new file mode 100644 index 0000000..2508a00 --- /dev/null +++ b/source/core/StarDataStream.hpp @@ -0,0 +1,393 @@ +#ifndef STAR_DATA_STREAM_HPP +#define STAR_DATA_STREAM_HPP + +#include "StarString.hpp" + +namespace Star { + +STAR_EXCEPTION(DataStreamException, IOException); + +// Writes complex types to bytes in a portable big-endian fashion. +class DataStream { +public: + DataStream(); + virtual ~DataStream() = default; + + static unsigned const CurrentStreamVersion = 1; + + // DataStream defaults to big-endian order for all primitive types + ByteOrder byteOrder() const; + void setByteOrder(ByteOrder byteOrder); + + // DataStream can optionally write strings as null terminated rather than + // length prefixed + bool nullTerminatedStrings() const; + void setNullTerminatedStrings(bool nullTerminatedStrings); + + // streamCompatibilityVersion defaults to CurrentStreamVersion, but can be + // changed for compatibility with older versions of DataStream serialization. + unsigned streamCompatibilityVersion() const; + void setStreamCompatibilityVersion(unsigned streamCompatibilityVersion); + + // Do direct reads and writes + virtual void readData(char* data, size_t len) = 0; + virtual void writeData(char const* data, size_t len) = 0; + + // These do not read / write sizes, they simply read / write directly. + ByteArray readBytes(size_t len); + void writeBytes(ByteArray const& ba); + + DataStream& operator<<(bool d); + DataStream& operator<<(char c); + DataStream& operator<<(int8_t d); + DataStream& operator<<(uint8_t d); + DataStream& operator<<(int16_t d); + DataStream& operator<<(uint16_t d); + DataStream& operator<<(int32_t d); + DataStream& operator<<(uint32_t d); + DataStream& operator<<(int64_t d); + DataStream& operator<<(uint64_t d); + DataStream& operator<<(float d); + DataStream& operator<<(double d); + + DataStream& operator>>(bool& d); + DataStream& operator>>(char& c); + DataStream& operator>>(int8_t& d); + DataStream& operator>>(uint8_t& d); + DataStream& operator>>(int16_t& d); + DataStream& operator>>(uint16_t& d); + DataStream& operator>>(int32_t& d); + DataStream& operator>>(uint32_t& d); + DataStream& operator>>(int64_t& d); + DataStream& operator>>(uint64_t& d); + DataStream& operator>>(float& d); + DataStream& operator>>(double& d); + + // Writes and reads a VLQ encoded integer. Can write / read anywhere from 1 + // to 10 bytes of data, with integers of smaller (absolute) value taking up + // fewer bytes. size_t version can be used to portably write a size_t type, + // and portably and efficiently handles the case of NPos. + + size_t writeVlqU(uint64_t i); + size_t writeVlqI(int64_t i); + size_t writeVlqS(size_t i); + + size_t readVlqU(uint64_t& i); + size_t readVlqI(int64_t& i); + size_t readVlqS(size_t& i); + + uint64_t readVlqU(); + int64_t readVlqI(); + size_t readVlqS(); + + // The following functions write / read data with length and then content + // following, but note that the length is encoded as an unsigned VLQ integer. + // String objects are encoded in utf8, and can optionally be written as null + // terminated rather than length then content. + + DataStream& operator<<(const char* s); + DataStream& operator<<(std::string const& d); + DataStream& operator<<(ByteArray const& d); + DataStream& operator<<(String const& s); + + DataStream& operator>>(std::string& d); + DataStream& operator>>(ByteArray& d); + DataStream& operator>>(String& s); + + // All enum types are automatically serializable + + template <typename EnumType, typename = typename std::enable_if<std::is_enum<EnumType>::value>::type> + DataStream& operator<<(EnumType const& e); + + template <typename EnumType, typename = typename std::enable_if<std::is_enum<EnumType>::value>::type> + DataStream& operator>>(EnumType& e); + + // Convenience method to avoid temporary. + template <typename T> + T read(); + + // Convenient argument style reading / writing + + template <typename Data> + void read(Data& data); + + template <typename Data> + void write(Data const& data); + + // Argument style reading / writing with casting. + + template <typename ReadType, typename Data> + void cread(Data& data); + + template <typename WriteType, typename Data> + void cwrite(Data const& data); + + // Argument style reading / writing of variable length integers. Arguments + // are explicitly casted, so things like enums are allowed. + + template <typename IntegralType> + void vuread(IntegralType& data); + + template <typename IntegralType> + void viread(IntegralType& data); + + template <typename IntegralType> + void vsread(IntegralType& data); + + template <typename IntegralType> + void vuwrite(IntegralType const& data); + + template <typename IntegralType> + void viwrite(IntegralType const& data); + + template <typename IntegralType> + void vswrite(IntegralType const& data); + + // Store a fixed point number as a variable length integer + + template <typename FloatType> + void vfread(FloatType& data, FloatType base); + + template <typename FloatType> + void vfwrite(FloatType const& data, FloatType base); + + // Read a shared / unique ptr, and store whether the pointer is initialized. + + template <typename PointerType, typename ReadFunction> + void pread(PointerType& pointer, ReadFunction readFunction); + + template <typename PointerType, typename WriteFunction> + void pwrite(PointerType const& pointer, WriteFunction writeFunction); + + template <typename PointerType> + void pread(PointerType& pointer); + + template <typename PointerType> + void pwrite(PointerType const& pointer); + + // WriteFunction should be void (DataStream& ds, Element const& e) + template <typename Container, typename WriteFunction> + void writeContainer(Container const& container, WriteFunction function); + + // ReadFunction should be void (DataStream& ds, Element& e) + template <typename Container, typename ReadFunction> + void readContainer(Container& container, ReadFunction function); + + template <typename Container, typename WriteFunction> + void writeMapContainer(Container& map, WriteFunction function); + + // Specialization of readContainer for map types (whose elements are a pair + // with the key type marked const) + template <typename Container, typename ReadFunction> + void readMapContainer(Container& map, ReadFunction function); + + template <typename Container> + void writeContainer(Container const& container); + + template <typename Container> + void readContainer(Container& container); + + template <typename Container> + void writeMapContainer(Container const& container); + + template <typename Container> + void readMapContainer(Container& container); + +private: + void writeStringData(char const* data, size_t len); + + ByteOrder m_byteOrder; + bool m_nullTerminatedStrings; + unsigned m_streamCompatibilityVersion; +}; + +template <typename EnumType, typename> +DataStream& DataStream::operator<<(EnumType const& e) { + *this << (typename std::underlying_type<EnumType>::type)e; + return *this; +} + +template <typename EnumType, typename> +DataStream& DataStream::operator>>(EnumType& e) { + typename std::underlying_type<EnumType>::type i; + *this >> i; + e = (EnumType)i; + return *this; +} + +template <typename T> +T DataStream::read() { + T t; + *this >> t; + return t; +} + +template <typename Data> +void DataStream::read(Data& data) { + *this >> data; +} + +template <typename Data> +void DataStream::write(Data const& data) { + *this << data; +} + +template <typename ReadType, typename Data> +void DataStream::cread(Data& data) { + ReadType v; + *this >> v; + data = (Data)v; +} + +template <typename WriteType, typename Data> +void DataStream::cwrite(Data const& data) { + WriteType v = (WriteType)data; + *this << v; +} + +template <typename IntegralType> +void DataStream::vuread(IntegralType& data) { + uint64_t i = readVlqU(); + data = (IntegralType)i; +} + +template <typename IntegralType> +void DataStream::viread(IntegralType& data) { + int64_t i = readVlqI(); + data = (IntegralType)i; +} + +template <typename IntegralType> +void DataStream::vsread(IntegralType& data) { + size_t s = readVlqS(); + data = (IntegralType)s; +} + +template <typename IntegralType> +void DataStream::vuwrite(IntegralType const& data) { + writeVlqU((uint64_t)data); +} + +template <typename IntegralType> +void DataStream::viwrite(IntegralType const& data) { + writeVlqI((int64_t)data); +} + +template <typename IntegralType> +void DataStream::vswrite(IntegralType const& data) { + writeVlqS((size_t)data); +} + +template <typename FloatType> +void DataStream::vfread(FloatType& data, FloatType base) { + int64_t i = readVlqI(); + data = (FloatType)i * base; +} + +template <typename FloatType> +void DataStream::vfwrite(FloatType const& data, FloatType base) { + writeVlqI((int64_t)round(data / base)); +} + +template <typename PointerType, typename ReadFunction> +void DataStream::pread(PointerType& pointer, ReadFunction readFunction) { + bool initialized = read<bool>(); + if (initialized) { + auto element = make_unique<typename std::decay<typename PointerType::element_type>::type>(); + readFunction(*this, *element); + pointer.reset(element.release()); + } else { + pointer.reset(); + } +} + +template <typename PointerType, typename WriteFunction> +void DataStream::pwrite(PointerType const& pointer, WriteFunction writeFunction) { + if (pointer) { + write(true); + writeFunction(*this, *pointer); + } else { + write(false); + } +} + +template <typename PointerType> +void DataStream::pread(PointerType& pointer) { + return pread(pointer, [](DataStream& ds, typename std::decay<typename PointerType::element_type>::type& value) { + ds.read(value); + }); +} + +template <typename PointerType> +void DataStream::pwrite(PointerType const& pointer) { + return pwrite(pointer, [](DataStream& ds, typename std::decay<typename PointerType::element_type>::type const& value) { + ds.write(value); + }); +} + +template <typename Container, typename WriteFunction> +void DataStream::writeContainer(Container const& container, WriteFunction function) { + writeVlqU(container.size()); + for (auto const& elem : container) + function(*this, elem); +} + +template <typename Container, typename ReadFunction> +void DataStream::readContainer(Container& container, ReadFunction function) { + container.clear(); + size_t size = readVlqU(); + for (size_t i = 0; i < size; ++i) { + typename Container::value_type elem; + function(*this, elem); + container.insert(container.end(), elem); + } +} + +template <typename Container, typename WriteFunction> +void DataStream::writeMapContainer(Container& map, WriteFunction function) { + writeVlqU(map.size()); + for (auto const& elem : map) + function(*this, elem.first, elem.second); +} + +template <typename Container, typename ReadFunction> +void DataStream::readMapContainer(Container& map, ReadFunction function) { + map.clear(); + size_t size = readVlqU(); + for (size_t i = 0; i < size; ++i) { + typename Container::key_type key; + typename Container::mapped_type mapped; + function(*this, key, mapped); + map.insert(make_pair(move(key), move(mapped))); + } +} + +template <typename Container> +void DataStream::writeContainer(Container const& container) { + writeContainer(container, [](DataStream& ds, typename Container::value_type const& element) { ds << element; }); +} + +template <typename Container> +void DataStream::readContainer(Container& container) { + readContainer(container, [](DataStream& ds, typename Container::value_type& element) { ds >> element; }); +} + +template <typename Container> +void DataStream::writeMapContainer(Container const& container) { + writeMapContainer(container, [](DataStream& ds, typename Container::key_type const& key, typename Container::mapped_type const& mapped) { + ds << key; + ds << mapped; + }); +} + +template <typename Container> +void DataStream::readMapContainer(Container& container) { + readMapContainer(container, [](DataStream& ds, typename Container::key_type& key, typename Container::mapped_type& mapped) { + ds >> key; + ds >> mapped; + }); +} + +} + +#endif diff --git a/source/core/StarDataStreamDevices.cpp b/source/core/StarDataStreamDevices.cpp new file mode 100644 index 0000000..0c37d8d --- /dev/null +++ b/source/core/StarDataStreamDevices.cpp @@ -0,0 +1,167 @@ +#include "StarDataStreamDevices.hpp" + +namespace Star { + +DataStreamFunctions::DataStreamFunctions(function<size_t(char*, size_t)> reader, function<size_t(char const*, size_t)> writer) + : m_reader(move(reader)), m_writer(move(writer)) {} + +void DataStreamFunctions::readData(char* data, size_t len) { + if (!m_reader) + throw DataStreamException("DataStreamFunctions no read function given"); + m_reader(data, len); +} + +void DataStreamFunctions::writeData(char const* data, size_t len) { + if (!m_writer) + throw DataStreamException("DataStreamFunctions no write function given"); + m_writer(data, len); +} + +DataStreamIODevice::DataStreamIODevice(IODevicePtr device) + : m_device(move(device)) {} + +IODevicePtr const& DataStreamIODevice::device() const { + return m_device; +} + +void DataStreamIODevice::seek(size_t pos, IOSeek mode) { + m_device->seek(pos, mode); +} + +bool DataStreamIODevice::atEnd() { + return m_device->atEnd(); +} + +StreamOffset DataStreamIODevice::pos() { + return m_device->pos(); +} + +void DataStreamIODevice::readData(char* data, size_t len) { + return m_device->readFull(data, len); +} + +void DataStreamIODevice::writeData(char const* data, size_t len) { + return m_device->writeFull(data, len); +} + +DataStreamBuffer::DataStreamBuffer() { + m_buffer = make_shared<Buffer>(); +} + +DataStreamBuffer::DataStreamBuffer(size_t s) + : DataStreamBuffer() { + reset(s); +} + +DataStreamBuffer::DataStreamBuffer(ByteArray b) + : DataStreamBuffer() { + reset(std::move(b)); +} + +void DataStreamBuffer::resize(size_t size) { + m_buffer->resize(size); +} + +void DataStreamBuffer::reserve(size_t size) { + m_buffer->reserve(size); +} + +void DataStreamBuffer::clear() { + m_buffer->clear(); +} + +BufferPtr const& DataStreamBuffer::device() const { + return m_buffer; +} + +ByteArray& DataStreamBuffer::data() { + return m_buffer->data(); +} + +ByteArray const& DataStreamBuffer::data() const { + return m_buffer->data(); +} + +ByteArray DataStreamBuffer::takeData() { + return m_buffer->takeData(); +} + +char* DataStreamBuffer::ptr() { + return m_buffer->ptr(); +} + +const char* DataStreamBuffer::ptr() const { + return m_buffer->ptr(); +} + +size_t DataStreamBuffer::size() const { + return m_buffer->dataSize(); +} + +bool DataStreamBuffer::empty() const { + return m_buffer->empty(); +} + +void DataStreamBuffer::seek(size_t pos, IOSeek mode) { + m_buffer->seek(pos, mode); +} + +bool DataStreamBuffer::atEnd() { + return m_buffer->atEnd(); +} + +size_t DataStreamBuffer::pos() { + return (size_t)m_buffer->pos(); +} + +void DataStreamBuffer::reset(size_t newSize) { + m_buffer->reset(newSize); +} + +void DataStreamBuffer::reset(ByteArray b) { + m_buffer->reset(move(b)); +} + +void DataStreamBuffer::readData(char* data, size_t len) { + m_buffer->readFull(data, len); +} + +void DataStreamBuffer::writeData(char const* data, size_t len) { + m_buffer->writeFull(data, len); +} + +DataStreamExternalBuffer::DataStreamExternalBuffer() {} + +DataStreamExternalBuffer::DataStreamExternalBuffer(char const* externalData, size_t len) : DataStreamExternalBuffer() { + reset(externalData, len); +} + +char const* DataStreamExternalBuffer::ptr() const { + return m_buffer.ptr(); +} + +size_t DataStreamExternalBuffer::size() const { + return m_buffer.dataSize(); +} + +bool DataStreamExternalBuffer::empty() const { + return m_buffer.empty(); +} + +void DataStreamExternalBuffer::seek(size_t pos, IOSeek mode) { + m_buffer.seek(pos, mode); +} + +bool DataStreamExternalBuffer::atEnd() { + return m_buffer.atEnd(); +} + +size_t DataStreamExternalBuffer::pos() { + return m_buffer.pos(); +} + +void DataStreamExternalBuffer::reset(char const* externalData, size_t len) { + m_buffer.reset(externalData, len); +} + +} diff --git a/source/core/StarDataStreamDevices.hpp b/source/core/StarDataStreamDevices.hpp new file mode 100644 index 0000000..88ee0e9 --- /dev/null +++ b/source/core/StarDataStreamDevices.hpp @@ -0,0 +1,249 @@ +#ifndef STAR_DATA_STREAM_BUFFER_HPP +#define STAR_DATA_STREAM_BUFFER_HPP + +#include "StarBuffer.hpp" +#include "StarDataStream.hpp" + +namespace Star { + +// Implements DataStream using function objects as implementations of read/write. +class DataStreamFunctions : public DataStream { +public: + // Either reader or writer can be unset, if unset then the readData/writeData + // implementations will throw DataStreamException as unimplemented. + DataStreamFunctions(function<size_t(char*, size_t)> reader, function<size_t(char const*, size_t)> writer); + + void readData(char* data, size_t len) override; + void writeData(char const* data, size_t len) override; + +private: + function<size_t(char*, size_t)> m_reader; + function<size_t(char const*, size_t)> m_writer; +}; + +class DataStreamIODevice : public DataStream { +public: + DataStreamIODevice(IODevicePtr device); + + IODevicePtr const& device() const; + + void seek(size_t pos, IOSeek seek = IOSeek::Absolute); + bool atEnd(); + StreamOffset pos(); + + void readData(char* data, size_t len) override; + void writeData(char const* data, size_t len) override; + +private: + IODevicePtr m_device; +}; + +class DataStreamBuffer : public DataStream { +public: + // Convenience methods to serialize to / from ByteArray directly without + // having to construct a temporary DataStreamBuffer to do it + + template <typename T> + static ByteArray serialize(T const& t); + + template <typename T> + static ByteArray serializeContainer(T const& t); + + template <typename T, typename WriteFunction> + static ByteArray serializeContainer(T const& t, WriteFunction writeFunction); + + template <typename T> + static ByteArray serializeMapContainer(T const& t); + + template <typename T, typename WriteFunction> + static ByteArray serializeMapContainer(T const& t, WriteFunction writeFunction); + + template <typename T> + static void deserialize(T& t, ByteArray data); + + template <typename T> + static void deserializeContainer(T& t, ByteArray data); + + template <typename T, typename ReadFunction> + static void deserializeContainer(T& t, ByteArray data, ReadFunction readFunction); + + template <typename T> + static void deserializeMapContainer(T& t, ByteArray data); + + template <typename T, typename ReadFunction> + static void deserializeMapContainer(T& t, ByteArray data, ReadFunction readFunction); + + template <typename T> + static T deserialize(ByteArray data); + + template <typename T> + static T deserializeContainer(ByteArray data); + + template <typename T, typename ReadFunction> + static T deserializeContainer(ByteArray data, ReadFunction readFunction); + + template <typename T> + static T deserializeMapContainer(ByteArray data); + + template <typename T, typename ReadFunction> + static T deserializeMapContainer(ByteArray data, ReadFunction readFunction); + + DataStreamBuffer(); + DataStreamBuffer(size_t initialSize); + DataStreamBuffer(ByteArray b); + + // Resize existing buffer to new size. + void resize(size_t size); + void reserve(size_t size); + void clear(); + + ByteArray& data(); + ByteArray const& data() const; + ByteArray takeData(); + + char* ptr(); + char const* ptr() const; + + BufferPtr const& device() const; + + size_t size() const; + bool empty() const; + + void seek(size_t pos, IOSeek seek = IOSeek::Absolute); + bool atEnd(); + size_t pos(); + + // Set new buffer. + void reset(size_t newSize); + void reset(ByteArray b); + + void readData(char* data, size_t len) override; + void writeData(char const* data, size_t len) override; + +private: + BufferPtr m_buffer; +}; + +class DataStreamExternalBuffer : public DataStream { +public: + DataStreamExternalBuffer(); + DataStreamExternalBuffer(char const* externalData, size_t len); + + char const* ptr() const; + + size_t size() const; + bool empty() const; + + void seek(size_t pos, IOSeek mode = IOSeek::Absolute); + bool atEnd(); + size_t pos(); + + void reset(char const* externalData, size_t len); + +private: + ExternalBuffer m_buffer; +}; + +template <typename T> +ByteArray DataStreamBuffer::serialize(T const& t) { + DataStreamBuffer ds; + ds.write(t); + return ds.takeData(); +} + +template <typename T> +ByteArray DataStreamBuffer::serializeContainer(T const& t) { + DataStreamBuffer ds; + ds.writeContainer(t); + return ds.takeData(); +} + +template <typename T, typename WriteFunction> +ByteArray DataStreamBuffer::serializeContainer(T const& t, WriteFunction writeFunction) { + DataStreamBuffer ds; + ds.writeContainer(t, writeFunction); + return ds.takeData(); +} + +template <typename T> +ByteArray DataStreamBuffer::serializeMapContainer(T const& t) { + DataStreamBuffer ds; + ds.writeMapContainer(t); + return ds.takeData(); +} + +template <typename T, typename WriteFunction> +ByteArray DataStreamBuffer::serializeMapContainer(T const& t, WriteFunction writeFunction) { + DataStreamBuffer ds; + ds.writeMapContainer(t, writeFunction); + return ds.takeData(); +} + +template <typename T> +void DataStreamBuffer::deserialize(T& t, ByteArray data) { + DataStreamBuffer ds(move(data)); + ds.read(t); +} + +template <typename T> +void DataStreamBuffer::deserializeContainer(T& t, ByteArray data) { + DataStreamBuffer ds(move(data)); + ds.readContainer(t); +} + +template <typename T, typename ReadFunction> +void DataStreamBuffer::deserializeContainer(T& t, ByteArray data, ReadFunction readFunction) { + DataStreamBuffer ds(move(data)); + ds.readContainer(t, readFunction); +} + +template <typename T> +void DataStreamBuffer::deserializeMapContainer(T& t, ByteArray data) { + DataStreamBuffer ds(move(data)); + ds.readMapContainer(t); +} + +template <typename T, typename ReadFunction> +void DataStreamBuffer::deserializeMapContainer(T& t, ByteArray data, ReadFunction readFunction) { + DataStreamBuffer ds(move(data)); + ds.readMapContainer(t, readFunction); +} + +template <typename T> +T DataStreamBuffer::deserialize(ByteArray data) { + T t; + deserialize(t, move(data)); + return t; +} + +template <typename T> +T DataStreamBuffer::deserializeContainer(ByteArray data) { + T t; + deserializeContainer(t, move(data)); + return t; +} + +template <typename T, typename ReadFunction> +T DataStreamBuffer::deserializeContainer(ByteArray data, ReadFunction readFunction) { + T t; + deserializeContainer(t, move(data), readFunction); + return t; +} + +template <typename T> +T DataStreamBuffer::deserializeMapContainer(ByteArray data) { + T t; + deserializeMapContainer(t, move(data)); + return t; +} + +template <typename T, typename ReadFunction> +T DataStreamBuffer::deserializeMapContainer(ByteArray data, ReadFunction readFunction) { + T t; + deserializeMapContainer(t, move(data), readFunction); + return t; +} + +} + +#endif diff --git a/source/core/StarDataStreamExtra.hpp b/source/core/StarDataStreamExtra.hpp new file mode 100644 index 0000000..c0965a9 --- /dev/null +++ b/source/core/StarDataStreamExtra.hpp @@ -0,0 +1,393 @@ +#ifndef STAR_DATA_STREAM_EXTRA_HPP +#define STAR_DATA_STREAM_EXTRA_HPP + +#include "StarDataStream.hpp" +#include "StarMultiArray.hpp" +#include "StarColor.hpp" +#include "StarPoly.hpp" +#include "StarMaybe.hpp" +#include "StarEither.hpp" +#include "StarOrderedMap.hpp" +#include "StarOrderedSet.hpp" + +namespace Star { + +struct DataStreamWriteFunctor { + DataStreamWriteFunctor(DataStream& ds) : ds(ds) {} + + DataStream& ds; + template <typename T> + void operator()(T const& t) const { + ds << t; + } +}; + +struct DataStreamReadFunctor { + DataStreamReadFunctor(DataStream& ds) : ds(ds) {} + + DataStream& ds; + template <typename T> + void operator()(T& t) const { + ds >> t; + } +}; + +inline DataStream& operator<<(DataStream& ds, Empty const&) { + return ds; +} + +inline DataStream& operator>>(DataStream& ds, Empty&) { + return ds; +} + +template <typename ElementT, size_t SizeN> +DataStream& operator<<(DataStream& ds, Array<ElementT, SizeN> const& array) { + for (size_t i = 0; i < SizeN; ++i) + ds << array[i]; + return ds; +} + +template <typename ElementT, size_t SizeN> +DataStream& operator>>(DataStream& ds, Array<ElementT, SizeN>& array) { + for (size_t i = 0; i < SizeN; ++i) + ds >> array[i]; + return ds; +} + +template <typename ElementT, size_t RankN> +DataStream& operator<<(DataStream& ds, MultiArray<ElementT, RankN> const& array) { + auto size = array.size(); + for (size_t i = 0; i < RankN; ++i) + ds.writeVlqU(size[i]); + + size_t count = array.count(); + for (size_t i = 0; i < count; ++i) + ds << array.atIndex(i); + + return ds; +} + +template <typename ElementT, size_t RankN> +DataStream& operator>>(DataStream& ds, MultiArray<ElementT, RankN>& array) { + typename MultiArray<ElementT, RankN>::SizeList size; + for (size_t i = 0; i < RankN; ++i) + size[i] = ds.readVlqU(); + + array.setSize(size); + size_t count = array.count(); + for (size_t i = 0; i < count; ++i) + ds >> array.atIndex(i); + + return ds; +} + +inline DataStream& operator<<(DataStream& ds, Color const& color) { + ds << color.toRgbaF(); + return ds; +} + +inline DataStream& operator>>(DataStream& ds, Color& color) { + color = Color::rgbaf(ds.read<Vec4F>()); + return ds; +} + +template <typename First, typename Second> +DataStream& operator<<(DataStream& ds, pair<First, Second> const& pair) { + ds << pair.first; + ds << pair.second; + return ds; +} + +template <typename First, typename Second> +DataStream& operator>>(DataStream& ds, pair<First, Second>& pair) { + ds >> pair.first; + ds >> pair.second; + return ds; +} + +template <typename Element> +DataStream& operator<<(DataStream& ds, std::shared_ptr<Element> const& ptr) { + ds.pwrite(ptr); + return ds; +} + +template <typename Element> +DataStream& operator>>(DataStream& ds, std::shared_ptr<Element>& ptr) { + ds.pread(ptr); + return ds; +} + +template <typename BaseList> +DataStream& operator<<(DataStream& ds, ListMixin<BaseList> const& list) { + ds.writeContainer(list); + return ds; +} + +template <typename BaseList> +DataStream& operator>>(DataStream& ds, ListMixin<BaseList>& list) { + ds.readContainer(list); + return ds; +} + +template <typename BaseSet> +DataStream& operator<<(DataStream& ds, SetMixin<BaseSet> const& set) { + ds.writeContainer(set); + return ds; +} + +template <typename BaseSet> +DataStream& operator>>(DataStream& ds, SetMixin<BaseSet>& set) { + ds.readContainer(set); + return ds; +} + +template <typename BaseMap> +DataStream& operator<<(DataStream& ds, MapMixin<BaseMap> const& map) { + ds.writeMapContainer(map); + return ds; +} + +template <typename BaseMap> +DataStream& operator>>(DataStream& ds, MapMixin<BaseMap>& map) { + ds.readMapContainer(map); + return ds; +} + +template <typename Key, typename Value, typename Compare, typename Allocator> +DataStream& operator>>(DataStream& ds, OrderedMap<Key, Value, Compare, Allocator>& map) { + ds.readMapContainer(map); + return ds; +} + +template <typename Key, typename Value, typename Compare, typename Allocator> +DataStream& operator<<(DataStream& ds, OrderedMap<Key, Value, Compare, Allocator> const& map) { + ds.writeMapContainer(map); + return ds; +} + +template <typename Key, typename Value, typename Hash, typename Equals, typename Allocator> +DataStream& operator>>(DataStream& ds, OrderedHashMap<Key, Value, Hash, Equals, Allocator>& map) { + ds.readMapContainer(map); + return ds; +} + +template <typename Key, typename Value, typename Hash, typename Equals, typename Allocator> +DataStream& operator<<(DataStream& ds, OrderedHashMap<Key, Value, Hash, Equals, Allocator> const& map) { + ds.writeMapContainer(map); + return ds; +} + +template <typename Value, typename Compare, typename Allocator> +DataStream& operator>>(DataStream& ds, OrderedSet<Value, Compare, Allocator>& set) { + ds.readContainer(set); + return ds; +} + +template <typename Value, typename Compare, typename Allocator> +DataStream& operator<<(DataStream& ds, OrderedSet<Value, Compare, Allocator> const& set) { + ds.writeContainer(set); + return ds; +} + +template <typename Value, typename Hash, typename Equals, typename Allocator> +DataStream& operator>>(DataStream& ds, OrderedHashSet<Value, Hash, Equals, Allocator>& set) { + ds.readContainer(set); + return ds; +} + +template <typename Value, typename Hash, typename Equals, typename Allocator> +DataStream& operator<<(DataStream& ds, OrderedHashSet<Value, Hash, Equals, Allocator> const& set) { + ds.writeContainer(set); + return ds; +} + +template <typename DataT> +DataStream& operator<<(DataStream& ds, Polygon<DataT> const& poly) { + ds.writeContainer(poly.vertexes()); + return ds; +} + +template <typename DataT> +DataStream& operator>>(DataStream& ds, Polygon<DataT>& poly) { + ds.readContainer(poly.vertexes()); + return ds; +} + +template <typename DataT, size_t Dimensions> +DataStream& operator<<(DataStream& ds, Box<DataT, Dimensions> const& box) { + ds.write(box.min()); + ds.write(box.max()); + return ds; +} + +template <typename DataT, size_t Dimensions> +DataStream& operator>>(DataStream& ds, Box<DataT, Dimensions>& box) { + ds.read(box.min()); + ds.read(box.max()); + return ds; +} + +template <typename DataT> +DataStream& operator<<(DataStream& ds, Matrix3<DataT> const& mat3) { + ds.write(mat3[0]); + ds.write(mat3[1]); + ds.write(mat3[2]); + return ds; +} + +template <typename DataT> +DataStream& operator>>(DataStream& ds, Matrix3<DataT>& mat3) { + ds.read(mat3[0]); + ds.read(mat3[1]); + ds.read(mat3[2]); + return ds; +} + +// Writes / reads a Variant type if every type has operator<< / operator>> +// defined for DataStream and if it is default constructible. + +template <typename FirstType, typename... RestTypes> +DataStream& operator<<(DataStream& ds, Variant<FirstType, RestTypes...> const& variant) { + ds.write<VariantTypeIndex>(variant.typeIndex()); + variant.call(DataStreamWriteFunctor{ds}); + return ds; +} + +template <typename FirstType, typename... RestTypes> +DataStream& operator>>(DataStream& ds, Variant<FirstType, RestTypes...>& variant) { + variant.makeType(ds.read<VariantTypeIndex>()); + variant.call(DataStreamReadFunctor{ds}); + return ds; +} + +template <typename... AllowedTypes> +DataStream& operator<<(DataStream& ds, MVariant<AllowedTypes...> const& mvariant) { + ds.write<VariantTypeIndex>(mvariant.typeIndex()); + mvariant.call(DataStreamWriteFunctor{ds}); + return ds; +} + +template <typename... AllowedTypes> +DataStream& operator>>(DataStream& ds, MVariant<AllowedTypes...>& mvariant) { + mvariant.makeType(ds.read<VariantTypeIndex>()); + mvariant.call(DataStreamReadFunctor{ds}); + return ds; +} + +// Writes / reads a Maybe type if the underlying type has operator<< / +// operator>> defined for DataStream + +template <typename T, typename WriteFunction> +void writeMaybe(DataStream& ds, Maybe<T> const& maybe, WriteFunction&& writeFunction) { + if (maybe) { + ds.write<bool>(true); + writeFunction(ds, *maybe); + } else { + ds.write<bool>(false); + } +} + +template <typename T, typename ReadFunction> +void readMaybe(DataStream& ds, Maybe<T>& maybe, ReadFunction&& readFunction) { + bool set = ds.read<bool>(); + if (set) { + T t; + readFunction(ds, t); + maybe = move(t); + } else { + maybe.reset(); + } +} + +template <typename T> +DataStream& operator<<(DataStream& ds, Maybe<T> const& maybe) { + writeMaybe(ds, maybe, [](DataStream& ds, T const& t) { ds << t; }); + return ds; +} + +template <typename T> +DataStream& operator>>(DataStream& ds, Maybe<T>& maybe) { + readMaybe(ds, maybe, [](DataStream& ds, T& t) { ds >> t; }); + return ds; +} + +// Writes / reads an Either type, an Either can either have a left or right +// value, or in edge cases, nothing. + +template <typename Left, typename Right> +DataStream& operator<<(DataStream& ds, Either<Left, Right> const& either) { + if (either.isLeft()) { + ds.write<uint8_t>(1); + ds.write(either.left()); + } else if (either.isRight()) { + ds.write<uint8_t>(2); + ds.write(either.right()); + } else { + ds.write<uint8_t>(0); + } + return ds; +} + +template <typename Left, typename Right> +DataStream& operator>>(DataStream& ds, Either<Left, Right>& either) { + uint8_t m = ds.read<uint8_t>(); + if (m == 1) + either = makeLeft(ds.read<Left>()); + else if (m == 2) + either = makeRight(ds.read<Right>()); + return ds; +} + +template <typename DataT, size_t Dimensions> +DataStream& operator<<(DataStream& ds, Line<DataT, Dimensions> const& line) { + ds.write(line.min()); + ds.write(line.max()); + return ds; +} + +template <typename DataT, size_t Dimensions> +DataStream& operator>>(DataStream& ds, Line<DataT, Dimensions>& line) { + ds.read(line.min()); + ds.read(line.max()); + return ds; +} + +template <typename T> +DataStream& operator<<(DataStream& ds, tuple<T> const& t) { + ds << get<0>(t); + return ds; +} + +struct DataStreamReader { + DataStream& ds; + + template <typename RT> + void operator()(RT& t) { + ds >> t; + } +}; + +struct DataStreamWriter { + DataStream& ds; + + template <typename RT> + void operator()(RT const& t) { + ds << t; + } +}; + +template <typename... T> +DataStream& operator>>(DataStream& ds, tuple<T...>& t) { + tupleCallFunction(t, DataStreamReader{ds}); + return ds; +} + +template <typename... T> +DataStream& operator<<(DataStream& ds, tuple<T...> const& t) { + tupleCallFunction(t, DataStreamWriter{ds}); + return ds; +} + +} + +#endif diff --git a/source/core/StarDynamicLib.hpp b/source/core/StarDynamicLib.hpp new file mode 100644 index 0000000..c479d8b --- /dev/null +++ b/source/core/StarDynamicLib.hpp @@ -0,0 +1,38 @@ +#ifndef STAR_PLATFORM_HPP +#define STAR_PLATFORM_HPP + +#include "StarString.hpp" + +namespace Star { + +STAR_CLASS(DynamicLib); + +class DynamicLib { +public: + // Returns the library extension normally used on the current platform + // including the '.', e.g. '.dll', '.so', '.dylib' + static String libraryExtension(); + + // Load a dll from the given filename. If the library is found and + // succesfully loaded, returns a handle to the library, otherwise nullptr. + static DynamicLibUPtr loadLibrary(String const& fileName); + + // Load a dll from the given name, minus extension. + static DynamicLibUPtr loadLibraryBase(String const& baseName); + + // Should return handle to currently running executable. Will always + // succeed. + static DynamicLibUPtr currentExecutable(); + + virtual ~DynamicLib() = default; + + virtual void* funcPtr(char const* name) = 0; +}; + +inline DynamicLibUPtr DynamicLib::loadLibraryBase(String const& baseName) { + return loadLibrary(baseName + libraryExtension()); +} + +} + +#endif diff --git a/source/core/StarDynamicLib_unix.cpp b/source/core/StarDynamicLib_unix.cpp new file mode 100644 index 0000000..dd81d1f --- /dev/null +++ b/source/core/StarDynamicLib_unix.cpp @@ -0,0 +1,46 @@ +#include "StarDynamicLib.hpp" + +#include <dlfcn.h> +#include <pthread.h> +#include <sys/time.h> +#include <errno.h> + +namespace Star { + +struct PrivateDynLib : public DynamicLib { + PrivateDynLib(void* handle) + : m_handle(handle) {} + + ~PrivateDynLib() { + dlclose(m_handle); + } + + void* funcPtr(const char* name) { + return dlsym(m_handle, name); + } + + void* m_handle; +}; + +String DynamicLib::libraryExtension() { +#ifdef STAR_SYSTEM_MACOS + return ".dylib"; +#else + return ".so"; +#endif +} + +DynamicLibUPtr DynamicLib::loadLibrary(String const& libraryName) { + void* handle = dlopen(libraryName.utf8Ptr(), RTLD_NOW); + if (handle == NULL) + return {}; + return make_unique<PrivateDynLib>(handle); +} + +DynamicLibUPtr DynamicLib::currentExecutable() { + void* handle = dlopen(NULL, 0); + starAssert(handle); + return make_unique<PrivateDynLib>(handle); +} + +} diff --git a/source/core/StarDynamicLib_windows.cpp b/source/core/StarDynamicLib_windows.cpp new file mode 100644 index 0000000..41272ce --- /dev/null +++ b/source/core/StarDynamicLib_windows.cpp @@ -0,0 +1,43 @@ +#include "StarDynamicLib.hpp" +#include "StarFormat.hpp" +#include "StarString_windows.hpp" + +#include <windows.h> + +namespace Star { + +class PrivateDynLib : public DynamicLib { +public: + PrivateDynLib(void* handle) + : m_handle(handle) {} + + ~PrivateDynLib() { + FreeLibrary((HMODULE)m_handle); + } + + void* funcPtr(const char* name) { + return (void*)GetProcAddress((HMODULE)m_handle, name); + } + +private: + void* m_handle; +}; + +String DynamicLib::libraryExtension() { + return ".dll"; +} + +DynamicLibUPtr DynamicLib::loadLibrary(String const& libraryName) { + void* handle = LoadLibraryW(stringToUtf16(libraryName).get()); + if (handle == NULL) + return {}; + return make_unique<PrivateDynLib>(handle); +} + +DynamicLibUPtr DynamicLib::currentExecutable() { + void* handle = GetModuleHandle(0); + starAssert(handle); + return make_unique<PrivateDynLib>(handle); +} + +} diff --git a/source/core/StarEither.hpp b/source/core/StarEither.hpp new file mode 100644 index 0000000..287e4bc --- /dev/null +++ b/source/core/StarEither.hpp @@ -0,0 +1,243 @@ +#ifndef STAR_EITHER_HPP +#define STAR_EITHER_HPP + +#include "StarVariant.hpp" + +namespace Star { + +STAR_EXCEPTION(EitherException, StarException); + +template <typename Value> +struct EitherLeftValue { + Value value; +}; + +template <typename Value> +struct EitherRightValue { + Value value; +}; + +template <typename Value> +EitherLeftValue<Value> makeLeft(Value value); + +template <typename Value> +EitherRightValue<Value> makeRight(Value value); + +// Container that contains exactly one of either Left or Right. +template <typename Left, typename Right> +class Either { +public: + // Constructs Either that contains a default constructed Left value + Either(); + + Either(EitherLeftValue<Left> left); + Either(EitherRightValue<Right> right); + + template <typename T> + Either(EitherLeftValue<T> left); + + template <typename T> + Either(EitherRightValue<T> right); + + Either(Either const& rhs); + Either(Either&& rhs); + + Either& operator=(Either const& rhs); + Either& operator=(Either&& rhs); + + template <typename T> + Either& operator=(EitherLeftValue<T> left); + + template <typename T> + Either& operator=(EitherRightValue<T> right); + + bool isLeft() const; + bool isRight() const; + + void setLeft(Left left); + void setRight(Right left); + + // left() and right() throw EitherException on invalid access + + Left const& left() const; + Right const& right() const; + + Left& left(); + Right& right(); + + Maybe<Left> maybeLeft() const; + Maybe<Right> maybeRight() const; + + // leftPtr() and rightPtr() do not throw on invalid access + + Left const* leftPtr() const; + Right const* rightPtr() const; + + Left* leftPtr(); + Right* rightPtr(); + +private: + typedef EitherLeftValue<Left> LeftType; + typedef EitherRightValue<Right> RightType; + + Variant<LeftType, RightType> m_value; +}; + +template <typename Value> +EitherLeftValue<Value> makeLeft(Value value) { + return {move(value)}; +} + +template <typename Value> +EitherRightValue<Value> makeRight(Value value) { + return {move(value)}; +} + +template <typename Left, typename Right> +Either<Left, Right>::Either() {} + +template <typename Left, typename Right> +Either<Left, Right>::Either(EitherLeftValue<Left> left) + : m_value(move(left)) {} + +template <typename Left, typename Right> +Either<Left, Right>::Either(EitherRightValue<Right> right) + : m_value(move(right)) {} + +template <typename Left, typename Right> +template <typename T> +Either<Left, Right>::Either(EitherLeftValue<T> left) + : Either(LeftType{move(left.value)}) {} + +template <typename Left, typename Right> +template <typename T> +Either<Left, Right>::Either(EitherRightValue<T> right) + : Either(RightType{move(right.value)}) {} + +template <typename Left, typename Right> +Either<Left, Right>::Either(Either const& rhs) + : m_value(rhs.m_value) {} + +template <typename Left, typename Right> +Either<Left, Right>::Either(Either&& rhs) + : m_value(move(rhs.m_value)) {} + +template <typename Left, typename Right> +Either<Left, Right>& Either<Left, Right>::operator=(Either const& rhs) { + m_value = rhs.m_value; + return *this; +} + +template <typename Left, typename Right> +Either<Left, Right>& Either<Left, Right>::operator=(Either&& rhs) { + m_value = move(rhs.m_value); + return *this; +} + +template <typename Left, typename Right> +template <typename T> +Either<Left, Right>& Either<Left, Right>::operator=(EitherLeftValue<T> left) { + m_value = LeftType{move(left.value)}; + return *this; +} + +template <typename Left, typename Right> +template <typename T> +Either<Left, Right>& Either<Left, Right>::operator=(EitherRightValue<T> right) { + m_value = RightType{move(right.value)}; + return *this; +} + +template <typename Left, typename Right> +bool Either<Left, Right>::isLeft() const { + return m_value.template is<LeftType>(); +} + +template <typename Left, typename Right> +bool Either<Left, Right>::isRight() const { + return m_value.template is<RightType>(); +} + +template <typename Left, typename Right> +void Either<Left, Right>::setLeft(Left left) { + m_value = LeftType{move(left)}; +} + +template <typename Left, typename Right> +void Either<Left, Right>::setRight(Right right) { + m_value = RightType{move(right)}; +} + +template <typename Left, typename Right> +Left const& Either<Left, Right>::left() const { + if (auto l = leftPtr()) + return *l; + throw EitherException("Improper access of left side of Either"); +} + +template <typename Left, typename Right> +Right const& Either<Left, Right>::right() const { + if (auto r = rightPtr()) + return *r; + throw EitherException("Improper access of right side of Either"); +} + +template <typename Left, typename Right> +Left& Either<Left, Right>::left() { + if (auto l = leftPtr()) + return *l; + throw EitherException("Improper access of left side of Either"); +} + +template <typename Left, typename Right> +Right& Either<Left, Right>::right() { + if (auto r = rightPtr()) + return *r; + throw EitherException("Improper access of right side of Either"); +} + +template <typename Left, typename Right> +Maybe<Left> Either<Left, Right>::maybeLeft() const { + if (auto l = leftPtr()) + return *l; + return {}; +} + +template <typename Left, typename Right> +Maybe<Right> Either<Left, Right>::maybeRight() const { + if (auto r = rightPtr()) + return *r; + return {}; +} + +template <typename Left, typename Right> +Left const* Either<Left, Right>::leftPtr() const { + if (auto l = m_value.template ptr<LeftType>()) + return &l->value; + return nullptr; +} + +template <typename Left, typename Right> +Right const* Either<Left, Right>::rightPtr() const { + if (auto r = m_value.template ptr<RightType>()) + return &r->value; + return nullptr; +} + +template <typename Left, typename Right> +Left* Either<Left, Right>::leftPtr() { + if (auto l = m_value.template ptr<LeftType>()) + return &l->value; + return nullptr; +} + +template <typename Left, typename Right> +Right* Either<Left, Right>::rightPtr() { + if (auto r = m_value.template ptr<RightType>()) + return &r->value; + return nullptr; +} + +} + +#endif diff --git a/source/core/StarEncode.cpp b/source/core/StarEncode.cpp new file mode 100644 index 0000000..b3a8b79 --- /dev/null +++ b/source/core/StarEncode.cpp @@ -0,0 +1,226 @@ +#include "StarEncode.hpp" + +namespace Star { + +size_t hexEncode(char const* data, size_t len, char* output, size_t outLen) { + static char const hex[] = "0123456789abcdef"; + + len = std::min(len, outLen / 2); + for (size_t i = 0; i < len; ++i) { + output[i * 2] = hex[(data[i] & 0xf0) >> 4]; + output[i * 2 + 1] = hex[(data[i] & 0x0f)]; + } + + return len * 2; +} + +size_t hexDecode(char const* src, size_t len, char* output, size_t outLen) { + for (size_t i = 0; i < len / 2; ++i) { + if (i >= outLen) + return i; + + uint8_t b1 = 0; + char c1 = src[i * 2]; + if (c1 >= '0' && c1 <= '9') + b1 = c1 - '0'; + else if (c1 >= 'A' && c1 <= 'F') + b1 = c1 - 'A' + 10; + else if (c1 >= 'a' && c1 <= 'f') + b1 = c1 - 'a' + 10; + + uint8_t b2 = 0; + char c2 = src[i * 2 + 1]; + if (c2 >= '0' && c2 <= '9') + b2 = c2 - '0'; + else if (c2 >= 'A' && c2 <= 'F') + b2 = c2 - 'A' + 10; + else if (c2 >= 'a' && c2 <= 'f') + b2 = c2 - 'a' + 10; + + *output++ = (b1 << 4) | b2; + } + + return len / 2; +} + +size_t nibbleDecode(char const* src, size_t len, char* output, size_t outLen) { + for (size_t i = 0; i < len; ++i) { + if (i >= outLen) + return i; + + uint8_t b = 0; + char c = src[i]; + if (c >= '0' && c <= '9') + b = c - '0'; + else if (c >= 'A' && c <= 'F') + b = c - 'A' + 10; + else if (c >= 'a' && c <= 'f') + b = c - 'a' + 10; + + *output++ = b; + } + + return len; +} + +static const std::string base64_chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + +size_t base64Encode(char const* data, size_t len, char* output, size_t outLen) { + if (outLen == 0) + return 0; + size_t written = 0; + + unsigned char ca3[3] = {0, 0, 0}; + unsigned char ca4[4] = {0, 0, 0, 0}; + const unsigned char* inPtr = (const unsigned char*)data; + int i = 0, j = 0, in_len = len; + + while (in_len--) { + ca3[i++] = *(inPtr++); + if (i == 3) { + ca4[0] = (ca3[0] & 0xfc) >> 2; + ca4[1] = ((ca3[0] & 0x03) << 4) + ((ca3[1] & 0xf0) >> 4); + ca4[2] = ((ca3[1] & 0x0f) << 2) + ((ca3[2] & 0xc0) >> 6); + ca4[3] = ca3[2] & 0x3f; + for (i = 0; (i < 4); i++) { + --outLen; + *output = base64_chars[ca4[i]]; + ++output; + ++written; + + if (outLen == 0) + return written; + } + i = 0; + } + } + + if (i) { + for (j = i; j < 3; j++) + ca3[j] = '\0'; + ca4[0] = (ca3[0] & 0xfc) >> 2; + ca4[1] = ((ca3[0] & 0x03) << 4) + ((ca3[1] & 0xf0) >> 4); + ca4[2] = ((ca3[1] & 0x0f) << 2) + ((ca3[2] & 0xc0) >> 6); + ca4[3] = ca3[2] & 0x3f; + for (j = 0; (j < i + 1); j++) { + --outLen; + *output = base64_chars[ca4[j]]; + ++output; + ++written; + + if (outLen == 0) + return written; + } + while ((i++ < 3)) { + --outLen; + *output = '='; + ++output; + ++written; + + if (outLen == 0) + return written; + } + } + + return written; +} + +static inline bool is_base64(unsigned char c) { + return (isalnum(c) || (c == '+') || (c == '/')); +} + +size_t base64Decode(char const* src, size_t len, char* output, size_t outLen) { + if (outLen == 0) + return 0; + + size_t written = 0; + + int i = 0, j = 0, in_ = 0, in_len = len; + unsigned char ca4[4], ca3[3]; + + while (in_len-- && (src[in_] != '=') && is_base64(src[in_])) { + ca4[i++] = src[in_++]; + if (i == 4) { + for (i = 0; i < 4; i++) + ca4[i] = base64_chars.find(ca4[i]); + ca3[0] = (ca4[0] << 2) + ((ca4[1] & 0x30) >> 4); + ca3[1] = ((ca4[1] & 0xf) << 4) + ((ca4[2] & 0x3c) >> 2); + ca3[2] = ((ca4[2] & 0x3) << 6) + ca4[3]; + for (i = 0; (i < 3); i++) { + --outLen; + *output = ca3[i]; + ++output; + ++written; + + if (outLen == 0) + return written; + } + i = 0; + } + } + + if (i) { + for (j = i; j < 4; j++) + ca4[j] = 0; + for (j = 0; j < 4; j++) + ca4[j] = base64_chars.find(ca4[j]); + ca3[0] = (ca4[0] << 2) + ((ca4[1] & 0x30) >> 4); + ca3[1] = ((ca4[1] & 0xf) << 4) + ((ca4[2] & 0x3c) >> 2); + ca3[2] = ((ca4[2] & 0x3) << 6) + ca4[3]; + for (j = 0; (j < i - 1); j++) { + --outLen; + *output = ca3[j]; + ++output; + ++written; + + if (outLen == 0) + return written; + } + } + + return written; +} + +String hexEncode(char const* data, size_t len) { + std::string res(len * 2, '\0'); + size_t encoded = hexEncode(data, len, &res[0], res.size()); + _unused(encoded); + starAssert(encoded == res.size()); + return move(res); +} + +String base64Encode(char const* data, size_t len) { + std::string res(len * 4 / 3 + 3, '\0'); + size_t encoded = base64Encode(data, len, &res[0], res.size()); + _unused(encoded); + starAssert(encoded <= res.size()); + res.resize(encoded); + return move(res); +} + +String hexEncode(ByteArray const& data) { + return hexEncode(data.ptr(), data.size()); +} + +ByteArray hexDecode(String const& encodedData) { + ByteArray res(encodedData.size() / 2, 0); + size_t decoded = hexDecode(encodedData.utf8Ptr(), encodedData.size(), res.ptr(), res.size()); + _unused(decoded); + starAssert(decoded == res.size()); + return res; +} + +String base64Encode(ByteArray const& data) { + return base64Encode(data.ptr(), data.size()); +} + +ByteArray base64Decode(String const& encodedData) { + ByteArray res(encodedData.size() * 3 / 4, 0); + size_t decoded = base64Decode(encodedData.utf8Ptr(), encodedData.size(), res.ptr(), res.size()); + _unused(decoded); + starAssert(decoded <= res.size()); + res.resize(decoded); + return res; +} + +} diff --git a/source/core/StarEncode.hpp b/source/core/StarEncode.hpp new file mode 100644 index 0000000..62350ae --- /dev/null +++ b/source/core/StarEncode.hpp @@ -0,0 +1,27 @@ +#ifndef STAR_ENCODE_HPP +#define STAR_ENCODE_HPP + +#include "StarString.hpp" +#include "StarByteArray.hpp" + +namespace Star { + +size_t hexEncode(char const* data, size_t len, char* output, size_t outLen = NPos); +size_t hexDecode(char const* src, size_t len, char* output, size_t outLen = NPos); +size_t nibbleDecode(char const* src, size_t len, char* output, size_t outLen = NPos); + +size_t base64Encode(char const* data, size_t len, char* output, size_t outLen = NPos); +size_t base64Decode(char const* src, size_t len, char* output, size_t outLen = NPos); + +String hexEncode(char const* data, size_t len); +String base64Encode(char const* data, size_t len); + +String hexEncode(ByteArray const& data); +ByteArray hexDecode(String const& encodedData); + +String base64Encode(ByteArray const& data); +ByteArray base64Decode(String const& encodedData); + +} + +#endif diff --git a/source/core/StarException.hpp b/source/core/StarException.hpp new file mode 100644 index 0000000..b427511 --- /dev/null +++ b/source/core/StarException.hpp @@ -0,0 +1,100 @@ +#ifndef STAR_EXCEPTION_HPP +#define STAR_EXCEPTION_HPP + +#include "StarFormat.hpp" + +namespace Star { + +class StarException : public std::exception { +public: + template <typename... Args> + static StarException format(char const* fmt, Args const&... args); + + StarException() noexcept; + virtual ~StarException() noexcept; + + explicit StarException(std::string message) noexcept; + explicit StarException(std::exception const& cause) noexcept; + StarException(std::string message, std::exception const& cause) noexcept; + + virtual char const* what() const noexcept override; + + // If the given exception is really StarException, then this will call + // StarException::printException, otherwise just prints std::exception::what. + friend void printException(std::ostream& os, std::exception const& e, bool fullStacktrace); + friend std::string printException(std::exception const& e, bool fullStacktrace); + friend OutputProxy outputException(std::exception const& e, bool fullStacktrace); + +protected: + StarException(char const* type, std::string message) noexcept; + StarException(char const* type, std::string message, std::exception const& cause) noexcept; + +private: + // Takes the ostream to print to, whether to print the full stacktrace. Must + // not bind 'this', may outlive the exception in the case of chained + // exception causes. + function<void(std::ostream&, bool)> m_printException; + + // m_printException will be called without the stack-trace to print + // m_whatBuffer, if the what() method is invoked. + mutable std::string m_whatBuffer; +}; + +void printException(std::ostream& os, std::exception const& e, bool fullStacktrace); +std::string printException(std::exception const& e, bool fullStacktrace); +OutputProxy outputException(std::exception const& e, bool fullStacktrace); + +void printStack(char const* message); + +// Log error and stack-trace and possibly show a dialog box if available, then +// abort. +void fatalError(char const* message, bool showStackTrace); +void fatalException(std::exception const& e, bool showStackTrace); + +#ifdef STAR_DEBUG +#define debugPrintStack() \ + { Star::printStack("Debug: file " STAR_STR(__FILE__) " line " STAR_STR(__LINE__)); } +#define starAssert(COND) \ + { \ + if (COND) \ + ; \ + else \ + Star::fatalError("assert failure in file " STAR_STR(__FILE__) " line " STAR_STR(__LINE__), true); \ + } +#else +#define debugPrintStack() \ + {} +#define starAssert(COND) \ + {} +#endif + +#define STAR_EXCEPTION(ClassName, BaseName) \ + class ClassName : public BaseName { \ + public: \ + template <typename... Args> \ + static ClassName format(char const* fmt, Args const&... args) { \ + return ClassName(strf(fmt, args...)); \ + } \ + ClassName() : BaseName(#ClassName, std::string()) {} \ + explicit ClassName(std::string message) : BaseName(#ClassName, move(message)) {} \ + explicit ClassName(std::exception const& cause) : BaseName(#ClassName, std::string(), cause) {} \ + ClassName(std::string message, std::exception const& cause) : BaseName(#ClassName, move(message), cause) {} \ + \ + protected: \ + ClassName(char const* type, std::string message) : BaseName(type, move(message)) {} \ + ClassName(char const* type, std::string message, std::exception const& cause) \ + : BaseName(type, move(message), cause) {} \ + } + +STAR_EXCEPTION(OutOfRangeException, StarException); +STAR_EXCEPTION(IOException, StarException); +STAR_EXCEPTION(MemoryException, StarException); + +template <typename... Args> +StarException StarException::format(char const* fmt, Args const&... args) { + return StarException(strf(fmt, args...)); +} + +} + +#endif diff --git a/source/core/StarException_unix.cpp b/source/core/StarException_unix.cpp new file mode 100644 index 0000000..db23c7c --- /dev/null +++ b/source/core/StarException_unix.cpp @@ -0,0 +1,136 @@ +#include "StarException.hpp" +#include "StarCasting.hpp" +#include "StarLogging.hpp" + +#include <execinfo.h> +#include <cstdlib> + +namespace Star { + +static size_t const StackLimit = 256; + +typedef pair<Array<void*, StackLimit>, size_t> StackCapture; + +inline StackCapture captureStack() { + StackCapture stackCapture; + stackCapture.second = backtrace(stackCapture.first.ptr(), StackLimit); + return stackCapture; +} + +OutputProxy outputStack(StackCapture stack) { + return OutputProxy([stack = move(stack)](std::ostream & os) { + char** symbols = backtrace_symbols(stack.first.ptr(), stack.second); + for (size_t i = 0; i < stack.second; ++i) { + os << symbols[i]; + if (i + 1 < stack.second) + os << std::endl; + } + + if (stack.second == StackLimit) + os << std::endl << "[Stack Output Limit Reached]"; + + ::free(symbols); + }); +} + +StarException::StarException() noexcept + : StarException(std::string("StarException")) {} + +StarException::~StarException() noexcept {} + +StarException::StarException(std::string message) noexcept + : StarException("StarException", move(message)) {} + +StarException::StarException(std::exception const& cause) noexcept + : StarException("StarException", std::string(), cause) {} + +StarException::StarException(std::string message, std::exception const& cause) noexcept + : StarException("StarException", move(message), cause) {} + +const char* StarException::what() const throw() { + if (m_whatBuffer.empty()) { + std::ostringstream os; + m_printException(os, false); + m_whatBuffer = os.str(); + } + return m_whatBuffer.c_str(); +} + +StarException::StarException(char const* type, std::string message) noexcept { + auto printException = [](std::ostream& os, bool fullStacktrace, char const* type, std::string message, StackCapture stack) { + os << "(" << type << ")"; + if (!message.empty()) + os << " " << message; + + if (fullStacktrace) { + os << std::endl; + os << outputStack(stack); + } + }; + + m_printException = bind(printException, _1, _2, type, move(message), captureStack()); +} + +StarException::StarException(char const* type, std::string message, std::exception const& cause) noexcept + : StarException(type, move(message)) { + auto printException = [](std::ostream& os, bool fullStacktrace, function<void(std::ostream&, bool)> self, function<void(std::ostream&, bool)> cause) { + self(os, fullStacktrace); + os << std::endl << "Caused by: "; + cause(os, fullStacktrace); + }; + + std::function<void(std::ostream&, bool)> printCause; + if (auto starException = as<StarException>(&cause)) { + printCause = bind(starException->m_printException, _1, _2); + } else { + printCause = bind([](std::ostream& os, bool, std::string causeWhat) { + os << "std::exception: " << causeWhat; + }, _1, _2, std::string(cause.what())); + } + + m_printException = bind(printException, _1, _2, m_printException, move(printCause)); +} + +std::string printException(std::exception const& e, bool fullStacktrace) { + std::ostringstream os; + printException(os, e, fullStacktrace); + return os.str(); +} + +void printException(std::ostream& os, std::exception const& e, bool fullStacktrace) { + if (auto starException = as<StarException>(&e)) + starException->m_printException(os, fullStacktrace); + else + os << "std::exception: " << e.what(); +} + +OutputProxy outputException(std::exception const& e, bool fullStacktrace) { + if (auto starException = as<StarException>(&e)) + return OutputProxy(bind(starException->m_printException, _1, fullStacktrace)); + else + return OutputProxy(bind([](std::ostream& os, std::string what) { os << "std::exception: " << what; }, _1, std::string(e.what()))); +} + +void printStack(char const* message) { + Logger::info("Stack Trace (%s)...\n%s", message, outputStack(captureStack())); +} + +void fatalError(char const* message, bool showStackTrace) { + if (showStackTrace) + Logger::error("Fatal Error: %s\n%s", message, outputStack(captureStack())); + else + Logger::error("Fatal Error: %s", message); + + std::abort(); +} + +void fatalException(std::exception const& e, bool showStackTrace) { + if (showStackTrace) + Logger::error("Fatal Exception caught: %s\nCaught at:\n%s", outputException(e, true), outputStack(captureStack())); + else + Logger::error("Fatal Exception caught: %s", outputException(e, showStackTrace)); + + std::abort(); +} + +} diff --git a/source/core/StarException_windows.cpp b/source/core/StarException_windows.cpp new file mode 100644 index 0000000..8892eb2 --- /dev/null +++ b/source/core/StarException_windows.cpp @@ -0,0 +1,253 @@ +#include "StarException.hpp" +#include "StarLogging.hpp" +#include "StarCasting.hpp" +#include "StarString_windows.hpp" + +#include <DbgHelp.h> + +namespace Star { + +struct WindowsSymInitializer { + WindowsSymInitializer() { + if (!SymInitialize(GetCurrentProcess(), NULL, TRUE)) + fatalError("SymInitialize failed", false); + } +}; +static WindowsSymInitializer g_windowsSymInitializer; + +struct DbgHelpLock { + DbgHelpLock() { + InitializeCriticalSection(&criticalSection); + } + + void lock() { + EnterCriticalSection(&criticalSection); + } + + void unlock() { + LeaveCriticalSection(&criticalSection); + } + + CRITICAL_SECTION criticalSection; +}; +static DbgHelpLock g_dbgHelpLock; + +static size_t const StackLimit = 256; + +typedef pair<Array<DWORD64, StackLimit>, size_t> StackCapture; + +inline StackCapture captureStack() { + HANDLE process = GetCurrentProcess(); + HANDLE thread = GetCurrentThread(); + + CONTEXT context; + DWORD image; + STACKFRAME64 stackFrame; + + memset(&context, 0, sizeof(CONTEXT)); + context.ContextFlags = CONTEXT_FULL; + + ZeroMemory(&stackFrame, sizeof(STACKFRAME64)); + stackFrame.AddrPC.Mode = AddrModeFlat; + stackFrame.AddrReturn.Mode = AddrModeFlat; + stackFrame.AddrFrame.Mode = AddrModeFlat; + stackFrame.AddrStack.Mode = AddrModeFlat; + +#ifdef STAR_ARCHITECTURE_I386 + +#ifdef STAR_COMPILER_MSVC + __asm { + mov [context.Ebp], ebp; + mov [context.Esp], esp; + call next; + next: + pop [context.Eip]; + } +#else + DWORD eip_val = 0; + DWORD esp_val = 0; + DWORD ebp_val = 0; + + __asm__ __volatile__("call 1f\n1: pop %0" : "=g"(eip_val)); + __asm__ __volatile__("movl %%esp, %0" : "=g"(esp_val)); + __asm__ __volatile__("movl %%ebp, %0" : "=g"(ebp_val)); + + context.Eip = eip_val; + context.Esp = esp_val; + context.Ebp = ebp_val; +#endif + + image = IMAGE_FILE_MACHINE_I386; + + stackFrame.AddrPC.Offset = context.Eip; + stackFrame.AddrReturn.Offset = context.Eip; + stackFrame.AddrFrame.Offset = context.Ebp; + stackFrame.AddrStack.Offset = context.Esp; + +#elif defined STAR_ARCHITECTURE_X86_64 + + RtlCaptureContext(&context); + + image = IMAGE_FILE_MACHINE_AMD64; + + stackFrame.AddrPC.Offset = context.Rip; + stackFrame.AddrReturn.Offset = context.Rip; + stackFrame.AddrFrame.Offset = context.Rbp; + stackFrame.AddrStack.Offset = context.Rsp; + +#endif + + g_dbgHelpLock.lock(); + + Array<DWORD64, StackLimit> addresses; + size_t count = 0; + for (size_t i = 0; i < StackLimit; i++) { + if (!StackWalk64(image, process, thread, &stackFrame, &context, NULL, SymFunctionTableAccess64, SymGetModuleBase64, NULL)) + break; + if (stackFrame.AddrPC.Offset == 0) + break; + addresses[i] = stackFrame.AddrPC.Offset; + ++count; + } + + g_dbgHelpLock.unlock(); + + return {addresses, count}; +} + +OutputProxy outputStack(StackCapture stack) { + return OutputProxy([stack = move(stack)](std::ostream & os) { + HANDLE process = GetCurrentProcess(); + g_dbgHelpLock.lock(); + for (size_t i = 0; i < stack.second; ++i) { + char buffer[sizeof(SYMBOL_INFO) + MAX_SYM_NAME * sizeof(TCHAR)]; + PSYMBOL_INFO symbol = (PSYMBOL_INFO)buffer; + symbol->SizeOfStruct = sizeof(SYMBOL_INFO); + symbol->MaxNameLen = MAX_SYM_NAME; + + DWORD64 displacement = 0; + format(os, "[%i] %p", i, stack.first[i]); + if (SymFromAddr(process, stack.first[i], &displacement, symbol)) + format(os, " %s", symbol->Name); + + if (i + 1 < stack.second) + os << std::endl; + } + + if (stack.second == StackLimit) + os << std::endl << "[Stack Output Limit Reached]"; + + g_dbgHelpLock.unlock(); + }); +} + +StarException::StarException() noexcept : StarException(std::string("StarException")) {} + +StarException::~StarException() noexcept {} + +StarException::StarException(std::string message) noexcept : StarException("StarException", move(message)) {} + +StarException::StarException(std::exception const& cause) noexcept + : StarException("StarException", std::string(), cause) {} + +StarException::StarException(std::string message, std::exception const& cause) noexcept + : StarException("StarException", move(message), cause) {} + +const char* StarException::what() const throw() { + if (m_whatBuffer.empty()) { + std::ostringstream os; + m_printException(os, false); + m_whatBuffer = os.str(); + } + return m_whatBuffer.c_str(); +} + +StarException::StarException(char const* type, std::string message) noexcept { + auto printException = []( + std::ostream& os, bool fullStacktrace, char const* type, std::string message, StackCapture stack) { + os << "(" << type << ")"; + if (!message.empty()) + os << " " << message; + + if (fullStacktrace) { + os << std::endl; + os << outputStack(stack); + } + }; + + m_printException = bind(printException, _1, _2, type, move(message), captureStack()); +} + +StarException::StarException(char const* type, std::string message, std::exception const& cause) noexcept + : StarException(type, move(message)) { + auto printException = [](std::ostream& os, + bool fullStacktrace, + function<void(std::ostream&, bool)> self, + function<void(std::ostream&, bool)> cause) { + self(os, fullStacktrace); + os << std::endl << "Caused by: "; + cause(os, fullStacktrace); + }; + + std::function<void(std::ostream&, bool)> printCause; + if (auto starException = as<StarException>(&cause)) { + printCause = bind(starException->m_printException, _1, _2); + } else { + printCause = bind([](std::ostream& os, bool, std::string causeWhat) { + os << "std::exception: " << causeWhat; + }, _1, _2, std::string(cause.what())); + } + + m_printException = bind(printException, _1, _2, m_printException, move(printCause)); +} + +std::string printException(std::exception const& e, bool fullStacktrace) { + std::ostringstream os; + printException(os, e, fullStacktrace); + return os.str(); +} + +void printException(std::ostream& os, std::exception const& e, bool fullStacktrace) { + if (auto starException = as<StarException>(&e)) + starException->m_printException(os, fullStacktrace); + else + os << "std::exception: " << e.what(); +} + +OutputProxy outputException(std::exception const& e, bool fullStacktrace) { + if (auto starException = as<StarException>(&e)) + return OutputProxy(bind(starException->m_printException, _1, fullStacktrace)); + else + return OutputProxy( + bind([](std::ostream& os, std::string what) { os << "std::exception: " << what; }, _1, std::string(e.what()))); +} + +void printStack(char const* message) { + Logger::info("Stack Trace (%s)...\n%s", message, outputStack(captureStack())); +} + +void fatalError(char const* message, bool showStackTrace) { + std::ostringstream ss; + ss << "Fatal Error: " << message << std::endl; + if (showStackTrace) + ss << outputStack(captureStack()); + + Logger::error(ss.str().c_str()); + MessageBoxW(NULL, stringToUtf16(ss.str()).get(), stringToUtf16("Error").get(), MB_OK | MB_ICONERROR | MB_SYSTEMMODAL); + + std::abort(); +} + +void fatalException(std::exception const& e, bool showStackTrace) { + std::ostringstream ss; + ss << "Fatal Exception caught: " << outputException(e, showStackTrace) << std::endl; + if (showStackTrace) + ss << "Caught at:" << std::endl << outputStack(captureStack()); + + Logger::error(ss.str().c_str()); + MessageBoxW(NULL, stringToUtf16(ss.str()).get(), stringToUtf16("Error").get(), MB_OK | MB_ICONERROR | MB_SYSTEMMODAL); + + std::abort(); +} + +} diff --git a/source/core/StarFile.cpp b/source/core/StarFile.cpp new file mode 100644 index 0000000..738a088 --- /dev/null +++ b/source/core/StarFile.cpp @@ -0,0 +1,240 @@ +#include "StarFile.hpp" +#include "StarFormat.hpp" + +#include <fstream> + +namespace Star { + +void File::makeDirectoryRecursive(String const& fileName) { + auto parent = dirName(fileName); + if (!isDirectory(parent)) + makeDirectoryRecursive(parent); + if (!isDirectory(fileName)) + makeDirectory(fileName); +} + +void File::removeDirectoryRecursive(String const& fileName) { + { + String fileInDir; + bool isDir; + + for (auto const& p : dirList(fileName)) { + std::tie(fileInDir, isDir) = p; + + fileInDir = relativeTo(fileName, fileInDir); + + if (isDir) + removeDirectoryRecursive(fileInDir); + else + remove(fileInDir); + } + } + + remove(fileName); +} + +void File::copy(String const& source, String const& target) { + auto sourceFile = File::open(source, IOMode::Read); + auto targetFile = File::open(target, IOMode::ReadWrite); + + targetFile->resize(0); + + char buf[1024]; + while (!sourceFile->atEnd()) { + size_t r = sourceFile->read(buf, 1024); + targetFile->writeFull(buf, r); + } +} + +FilePtr File::open(const String& filename, IOMode mode) { + auto file = make_shared<File>(filename); + file->open(mode); + return file; +} + +ByteArray File::readFile(String const& filename) { + FilePtr file = File::open(filename, IOMode::Read); + ByteArray bytes; + while (!file->atEnd()) { + char buffer[1024]; + size_t r = file->read(buffer, 1024); + bytes.append(buffer, r); + } + + return bytes; +} + +String File::readFileString(String const& filename) { + FilePtr file = File::open(filename, IOMode::Read); + std::string str; + while (!file->atEnd()) { + char buffer[1024]; + size_t r = file->read(buffer, 1024); + for (size_t i = 0; i < r; ++i) + str.push_back(buffer[i]); + } + + return str; +} + +StreamOffset File::fileSize(String const& filename) { + return File::open(filename, IOMode::Read)->size(); +} + +void File::writeFile(char const* data, size_t len, String const& filename) { + FilePtr file = File::open(filename, IOMode::Write | IOMode::Truncate); + file->writeFull(data, len); +} + +void File::writeFile(ByteArray const& data, String const& filename) { + writeFile(data.ptr(), data.size(), filename); +} + +void File::writeFile(String const& data, String const& filename) { + writeFile(data.utf8Ptr(), data.utf8Size(), filename); +} + +void File::overwriteFileWithRename(ByteArray const& data, String const& filename, String const& newSuffix) { + overwriteFileWithRename(data.ptr(), data.size(), filename, newSuffix); +} + +void File::overwriteFileWithRename(String const& data, String const& filename, String const& newSuffix) { + overwriteFileWithRename(data.utf8Ptr(), data.utf8Size(), filename, newSuffix); +} + +void File::backupFileInSequence(String const& targetFile, unsigned maximumBackups, String const& backupExtensionPrefix) { + for (unsigned i = maximumBackups; i > 0; --i) { + String curExtension = i == 1 ? "" : strf("%s%s", backupExtensionPrefix, i - 1); + String nextExtension = strf("%s%s", backupExtensionPrefix, i); + + if (File::isFile(targetFile + curExtension)) + File::copy(targetFile + curExtension, targetFile + nextExtension); + } +} + +File::File() + : IODevice(IOMode::Closed) { + m_file = 0; +} + +File::File(String filename) + : IODevice(IOMode::Closed), m_filename(move(filename)), m_file(0) {} + +File::~File() { + close(); +} + +StreamOffset File::pos() { + if (!m_file) + throw IOException("pos called on closed File"); + + return ftell(m_file); +} + +void File::seek(StreamOffset offset, IOSeek seekMode) { + if (!m_file) + throw IOException("seek called on closed File"); + + fseek(m_file, offset, seekMode); +} + +StreamOffset File::size() { + return fsize(m_file); +} + +bool File::atEnd() { + if (!m_file) + throw IOException("eof called on closed File"); + + return ftell(m_file) >= fsize(m_file); +} + +size_t File::read(char* data, size_t len) { + if (!m_file) + throw IOException("read called on closed File"); + + if (!isReadable()) + throw IOException("read called on non-readable File"); + + return fread(m_file, data, len); +} + +size_t File::write(const char* data, size_t len) { + if (!m_file) + throw IOException("write called on closed File"); + + if (!isWritable()) + throw IOException("write called on non-writable File"); + + return fwrite(m_file, data, len); +} + +size_t File::readAbsolute(StreamOffset readPosition, char* data, size_t len) { + return pread(m_file, data, len, readPosition); +} + +size_t File::writeAbsolute(StreamOffset writePosition, char const* data, size_t len) { + return pwrite(m_file, data, len, writePosition); +} + +String File::fileName() const { + return m_filename; +} + +void File::setFilename(String filename) { + if (isOpen()) + throw IOException("Cannot call setFilename while File is open"); + m_filename = move(filename); +} + +void File::remove() { + close(); + if (m_filename.empty()) + throw IOException("Cannot remove file, no filename set"); + remove(m_filename); +} + +void File::resize(StreamOffset s) { + bool tempOpen = false; + if (!isOpen()) { + tempOpen = true; + open(mode()); + } + + File::resize(m_file, s); + + if (tempOpen) + close(); +} + +void File::sync() { + if (!m_file) + throw IOException("sync called on closed File"); + + fsync(m_file); +} + +void File::open(IOMode m) { + close(); + if (m_filename.empty()) + throw IOException("Cannot open file, no filename set"); + + m_file = fopen(m_filename.utf8Ptr(), m); + setMode(m); +} + +void File::close() { + if (m_file) + fclose(m_file); + m_file = 0; + setMode(IOMode::Closed); +} + +String File::deviceName() const { + if (m_filename.empty()) + return "<unnamed temp file>"; + else + return m_filename; +} + +} diff --git a/source/core/StarFile.hpp b/source/core/StarFile.hpp new file mode 100644 index 0000000..c75dd4c --- /dev/null +++ b/source/core/StarFile.hpp @@ -0,0 +1,149 @@ +#ifndef STAR_FILE_HPP +#define STAR_FILE_HPP + +#include "StarIODevice.hpp" +#include "StarString.hpp" + +namespace Star { + +STAR_CLASS(File); + +// All file methods are thread safe. +class File : public IODevice { +public: + // Converts the passed in path to use the platform specific directory + // separators only (Windows supports '/' just fine, this is mostly for + // uniform appearance). Does *nothing else* (no validity checks, etc). + static String convertDirSeparators(String const& path); + + // All static file operations here throw IOException on error. + // get the current working directory + static String currentDirectory(); + // set the current working directory. + static void changeDirectory(String const& dirName); + static void makeDirectory(String const& dirName); + static void makeDirectoryRecursive(String const& dirName); + + // List all files or directories under given directory. skipDots skips the + // special '.' and '..' entries. Bool value is true for directories. + static List<pair<String, bool>> dirList(String const& dirName, bool skipDots = true); + + // Returns the final component of the given path with no directory separators + static String baseName(String const& fileName); + // All components of the given path minus the final component, separated by + // the directory separator + static String dirName(String const& fileName); + + // Resolve a path relative to another path. If the given path is absolute, + // then the given path is returned unmodified. + static String relativeTo(String const& relativeTo, String const& path); + + // Resolve the given possibly relative path into an absolute path. + static String fullPath(String const& path); + + static String temporaryFileName(); + + // Creates and opens a new ReadWrite temporary file with a real path that can + // be closed and re-opened. Will not be removed automatically. + static FilePtr temporaryFile(); + + // Creates and opens new ReadWrite temporary file and opens it. This file + // has no filename and will be removed on close. + static FilePtr ephemeralFile(); + + // Creates a new temporary directory and reutrns the path. Will not be + // removed automatically. + static String temporaryDirectory(); + + static bool exists(String const& path); + + // Does the file exist and is it a regular file (not a directory or special + // file)? + static bool isFile(String const& path); + // Is the file a directory? + static bool isDirectory(String const& path); + + static void remove(String const& filename); + static void removeDirectoryRecursive(String const& filename); + + // Moves the source file to the target path, overwriting the target path if + // it already exists. + static void rename(String const& source, String const& target); + + // Copies the source file to the target, overwriting the target path if it + // already exists. + static void copy(String const& source, String const& target); + + static ByteArray readFile(String const& filename); + static String readFileString(String const& filename); + static StreamOffset fileSize(String const& filename); + + static void writeFile(char const* data, size_t len, String const& filename); + static void writeFile(ByteArray const& data, String const& filename); + static void writeFile(String const& data, String const& filename); + + // Write a new file, potentially overwriting an existing file, in the safest + // way possible while preserving the old file in the same directory until the + // operation completes. Writes to the same path as the existing file to + // avoid different partition copying. This may clobber anything in the given + // path that matches filename + newSuffix. + static void overwriteFileWithRename(char const* data, size_t len, String const& filename, String const& newSuffix = ".new"); + static void overwriteFileWithRename(ByteArray const& data, String const& filename, String const& newSuffix = ".new"); + static void overwriteFileWithRename(String const& data, String const& filename, String const& newSuffix = ".new"); + + static void backupFileInSequence(String const& targetFile, unsigned maximumBackups, String const& backupExtensionPrefix = "."); + + static FilePtr open(String const& filename, IOMode mode); + + File(); + File(String filename); + virtual ~File(); + + String fileName() const; + void setFilename(String filename); + + // File is closed before removal. + void remove(); + + StreamOffset pos() override; + void seek(StreamOffset pos, IOSeek seek = IOSeek::Absolute) override; + void resize(StreamOffset size) override; + StreamOffset size() override; + bool atEnd() override; + size_t read(char* data, size_t len) override; + size_t write(char const* data, size_t len) override; + + // Do an immediate read / write of an absolute location in the file, without + // modifying the current file cursor. Safe to call in a threaded context + // with other reads and writes, but not safe vs changing the File state like + // open and close. + size_t readAbsolute(StreamOffset readPosition, char* data, size_t len) override; + size_t writeAbsolute(StreamOffset writePosition, char const* data, size_t len) override; + + void open(IOMode mode) override; + void close() override; + + void sync() override; + + String deviceName() const override; + +private: + static void* fopen(char const* filename, IOMode mode); + static void fseek(void* file, StreamOffset offset, IOSeek seek); + static StreamOffset ftell(void* file); + static size_t fread(void* file, char* data, size_t len); + static size_t fwrite(void* file, char const* data, size_t len); + static void fsync(void* file); + static void fclose(void* file); + static StreamOffset fsize(void* file); + static size_t pread(void* file, char* data, size_t len, StreamOffset absPosition); + static size_t pwrite(void* file, char const* data, size_t len, StreamOffset absPosition); + static void resize(void* file, StreamOffset size); + + String m_filename; + void* m_file; +}; + +} + +#endif diff --git a/source/core/StarFile_unix.cpp b/source/core/StarFile_unix.cpp new file mode 100644 index 0000000..aa22bc8 --- /dev/null +++ b/source/core/StarFile_unix.cpp @@ -0,0 +1,295 @@ +#include "StarFile.hpp" +#include "StarFormat.hpp" +#include "StarRandom.hpp" +#include "StarEncode.hpp" + +#include <errno.h> +#include <string.h> +#include <limits.h> +#include <stdlib.h> +#include <dirent.h> +#include <unistd.h> +#include <libgen.h> +#include <fcntl.h> +#include <sys/stat.h> + +#ifdef STAR_SYSTEM_MACOSX +#include <mach-o/dyld.h> +#elif defined STAR_SYSTEM_FREEBSD +#include <sys/types.h> +#include <sys/sysctl.h> +#endif + +namespace Star { + +namespace { + int fdFromHandle(void* ptr) { + return (int)(intptr_t)ptr; + } + + void* handleFromFd(int handle) { + return (void*)(intptr_t)handle; + } +} + +String File::convertDirSeparators(String const& path) { + return path.replace("\\", "/"); +} + +String File::currentDirectory() { + char buffer[PATH_MAX]; + if (::getcwd(buffer, PATH_MAX) == NULL) + throw IOException("getcwd failed"); + + return String(buffer); +} + +void File::changeDirectory(const String& dirName) { + if (::chdir(dirName.utf8Ptr()) != 0) + throw IOException(strf("could not change directory to %s", dirName)); +} + +void File::makeDirectory(String const& dirName) { + if (::mkdir(dirName.utf8Ptr(), 0777) != 0) + throw IOException(strf("could not create directory '%s', %s", dirName, strerror(errno))); +} + +List<pair<String, bool>> File::dirList(const String& dirName, bool skipDots) { + List<std::pair<String, bool>> fileList; + DIR* directory = ::opendir(dirName.utf8Ptr()); + if (directory == NULL) + throw IOException::format("dirList failed on dir: '%s'", dirName); + + for (dirent* entry = ::readdir(directory); entry != NULL; entry = ::readdir(directory)) { + String entryString = entry->d_name; + if (!skipDots || (entryString != "." && entryString != "..")) { + bool isDirectory = false; + if (entry->d_type == DT_DIR) { + isDirectory = true; + } else if (entry->d_type == DT_LNK || entry->d_type == DT_UNKNOWN) { + isDirectory = File::isDirectory(File::relativeTo(dirName, entryString)); + } + fileList.append({entryString, isDirectory}); + } + } + ::closedir(directory); + + return fileList; +} + +String File::baseName(const String& fileName) { + String ret; + + std::string file = fileName.utf8(); + char* fn = new char[file.size() + 1]; + std::copy(file.begin(), file.end(), fn); + fn[file.size()] = 0; + ret = String(::basename(fn)); + delete[] fn; + + return ret; +} + +String File::dirName(const String& fileName) { + String ret; + + std::string file = fileName.utf8(); + char* fn = new char[file.size() + 1]; + std::copy(file.begin(), file.end(), fn); + fn[file.size()] = 0; + ret = String(::dirname(fn)); + delete[] fn; + + return ret; +} + +String File::relativeTo(String const& relativeTo, String const& path) { + if (path.beginsWith("/")) + return path; + return relativeTo.trimEnd("/") + '/' + path; +} + +String File::fullPath(const String& fileName) { + char buffer[PATH_MAX]; + + if (::realpath(fileName.utf8Ptr(), buffer) == NULL) + throw IOException::format("realpath failed on file: '%s' problem path was: '%s'", fileName, buffer); + + return String(buffer); +} + +String File::temporaryFileName() { + return relativeTo(P_tmpdir, strf("starbound.tmpfile.%s", hexEncode(Random::randBytes(16)))); +} + +FilePtr File::temporaryFile() { + return open(temporaryFileName(), IOMode::ReadWrite); +} + +FilePtr File::ephemeralFile() { + auto file = make_shared<File>(); + ByteArray path = ByteArray::fromCStringWithNull(relativeTo(P_tmpdir, "starbound.tmpfile.XXXXXXXX").utf8Ptr()); + auto res = mkstemp(path.ptr()); + if (res < 0) + throw IOException::format("tmpfile error: %s", strerror(errno)); + if (::unlink(path.ptr()) < 0) + throw IOException::format("Could not remove mkstemp file when creating ephemeralFile: %s", strerror(errno)); + file->m_file = handleFromFd(res); + file->setMode(IOMode::ReadWrite); + return file; +} + +String File::temporaryDirectory() { + String dirname = relativeTo(P_tmpdir, strf("starbound.tmpdir.%s", hexEncode(Random::randBytes(16)))); + makeDirectory(dirname); + return dirname; +} + +bool File::exists(String const& path) { + struct stat st_buf; + int status = stat(path.utf8Ptr(), &st_buf); + return status == 0; +} + +bool File::isFile(String const& path) { + struct stat st_buf; + int status = stat(path.utf8Ptr(), &st_buf); + if (status != 0) + return false; + + return S_ISREG(st_buf.st_mode); +} + +bool File::isDirectory(String const& path) { + struct stat st_buf; + int status = stat(path.utf8Ptr(), &st_buf); + if (status != 0) + return false; + + return S_ISDIR(st_buf.st_mode); +} + +void File::remove(String const& filename) { + if (::remove(filename.utf8Ptr()) < 0) + throw IOException::format("remove error: %s", strerror(errno)); +} + +void File::rename(String const& source, String const& target) { + if (::rename(source.utf8Ptr(), target.utf8Ptr()) < 0) + throw IOException::format("rename error: %s", strerror(errno)); +} + +void File::overwriteFileWithRename(char const* data, size_t len, String const& filename, String const& newSuffix) { + String newFile = filename + newSuffix; + writeFile(data, len, newFile); + File::rename(newFile, filename); +} + +void* File::fopen(char const* filename, IOMode mode) { + int oflag = 0; + + if (mode & IOMode::Read && mode & IOMode::Write) + oflag |= O_RDWR | O_CREAT; + else if (mode & IOMode::Read) + oflag |= O_RDONLY; + else if (mode & IOMode::Write) + oflag |= O_WRONLY | O_CREAT; + + if (mode & IOMode::Truncate) + oflag |= O_TRUNC; + + int fd = ::open(filename, oflag, 0666); + if (fd < 0) + throw IOException::format("Error opening file '%s', error: %s", filename, strerror(errno)); + + if (mode & IOMode::Append) { + if (lseek(fd, 0, SEEK_END) < 0) + throw IOException::format("Error opening file '%s', cannot seek: %s", filename, strerror(errno)); + } + + return handleFromFd(fd); +} + +void File::fseek(void* f, StreamOffset offset, IOSeek seekMode) { + auto fd = fdFromHandle(f); + int retCode; + if (seekMode == IOSeek::Relative) + retCode = lseek(fd, offset, SEEK_CUR); + else if (seekMode == IOSeek::Absolute) + retCode = lseek(fd, offset, SEEK_SET); + else + retCode = lseek(fd, offset, SEEK_END); + + if (retCode < 0) + throw IOException::format("Seek error: %s", strerror(errno)); +} + +StreamOffset File::ftell(void* f) { + return lseek(fdFromHandle(f), 0, SEEK_CUR); +} + +size_t File::fread(void* file, char* data, size_t len) { + if (len == 0) + return 0; + + auto fd = fdFromHandle(file); + auto ret = ::read(fd, data, len); + if (ret < 0) { + if (errno == EAGAIN || errno == EINTR) + return 0; + throw IOException::format("Read error: %s", strerror(errno)); + } else { + return ret; + } +} + +size_t File::fwrite(void* file, char const* data, size_t len) { + if (len == 0) + return 0; + + auto fd = fdFromHandle(file); + auto ret = ::write(fd, data, len); + if (ret < 0) { + if (errno == EAGAIN || errno == EINTR) + return 0; + throw IOException::format("Write error: %s", strerror(errno)); + } else { + return ret; + } +} + +void File::fsync(void* file) { + auto fd = fdFromHandle(file); +#ifdef STAR_SYSTEM_LINUX + ::fdatasync(fd); +#else + ::fsync(fd); +#endif +} + +void File::fclose(void* file) { + if (::close(fdFromHandle(file)) < 0) + throw IOException::format("Close error: %s", strerror(errno)); +} + +StreamOffset File::fsize(void* file) { + StreamOffset pos = ftell(file); + StreamOffset size = lseek(fdFromHandle(file), 0, SEEK_END); + lseek(fdFromHandle(file), pos, SEEK_SET); + return size; +} + +size_t File::pread(void* file, char* data, size_t len, StreamOffset position) { + return ::pread(fdFromHandle(file), data, len, position); +} + +size_t File::pwrite(void* file, char const* data, size_t len, StreamOffset position) { + return ::pwrite(fdFromHandle(file), data, len, position); +} + +void File::resize(void* f, StreamOffset size) { + if (::ftruncate(fdFromHandle(f), size) < 0) + throw IOException::format("resize error: %s", strerror(errno)); +} + +} diff --git a/source/core/StarFile_windows.cpp b/source/core/StarFile_windows.cpp new file mode 100644 index 0000000..bb13ffc --- /dev/null +++ b/source/core/StarFile_windows.cpp @@ -0,0 +1,412 @@ +#include "StarFile.hpp" +#include "StarFormat.hpp" +#include "StarRandom.hpp" +#include "StarEncode.hpp" +#include "StarMathCommon.hpp" +#include "StarThread.hpp" + +#include "StarString_windows.hpp" + +#include <errno.h> +#include <io.h> +#include <stdio.h> +#include <windows.h> + +#ifndef MAX_PATH +#define MAX_PATH 1024 +#endif + +namespace Star { + +namespace { + OVERLAPPED makeOverlapped(StreamOffset offset) { + OVERLAPPED overlapped = {}; + overlapped.Offset = offset; + overlapped.OffsetHigh = offset >> 32; + return overlapped; + } +} + +String File::convertDirSeparators(String const& path) { + return path.replace("/", "\\"); +} + +String File::currentDirectory() { + WCHAR buffer[MAX_PATH]; + size_t len = GetCurrentDirectoryW(MAX_PATH, buffer); + if (len == 0) + throw IOException("GetCurrentDirectory failed"); + + return utf16ToString(buffer); +} + +void File::changeDirectory(const String& dirName) { + if (!SetCurrentDirectoryW(stringToUtf16(dirName).get())) + throw IOException(strf("could not change directory to %s", dirName)); +} + +void File::makeDirectory(String const& dirName) { + if (CreateDirectoryW(stringToUtf16(dirName).get(), NULL) == 0) { + auto error = GetLastError(); + throw IOException(strf("could not create directory '%s', %s", dirName, error)); + } +} + +bool File::exists(String const& path) { + WIN32_FIND_DATAW findFileData; + const HANDLE handle = FindFirstFileW(stringToUtf16(path).get(), &findFileData); + if (handle == INVALID_HANDLE_VALUE) + return false; + FindClose(handle); + return true; +} + +bool File::isFile(String const& path) { + WIN32_FIND_DATAW findFileData; + const HANDLE handle = FindFirstFileW(stringToUtf16(path).get(), &findFileData); + if (handle == INVALID_HANDLE_VALUE) + return false; + FindClose(handle); + return (FILE_ATTRIBUTE_DIRECTORY & findFileData.dwFileAttributes) == 0; +} + +bool File::isDirectory(String const& path) { + DWORD attribs = GetFileAttributesW(stringToUtf16(path.trimEnd("\\/")).get()); + if (attribs == INVALID_FILE_ATTRIBUTES) + return false; + return attribs & FILE_ATTRIBUTE_DIRECTORY; +} + +String File::fullPath(const String& path) { + WCHAR buffer[MAX_PATH]; + + size_t fullpath_size; + WCHAR* lpszLastNamePart; + + fullpath_size = GetFullPathNameW(stringToUtf16(path).get(), (DWORD)MAX_PATH, buffer, (WCHAR**)&lpszLastNamePart); + if (0 == fullpath_size) + throw IOException::format("GetFullPathName failed on path: '%s'", path); + if (fullpath_size >= MAX_PATH) + throw IOException::format("GetFullPathName failed on path: '%s'", path); + + return utf16ToString(buffer); +} + +List<std::pair<String, bool>> File::dirList(const String& dirName, bool skipDots) { + List<std::pair<String, bool>> fileList; + WIN32_FIND_DATAW findFileData; + HANDLE hFind; + + hFind = FindFirstFileW(stringToUtf16(File::relativeTo(dirName, "*")).get(), &findFileData); + if (hFind == INVALID_HANDLE_VALUE) + throw IOException(strf("Invalid file handle in dirList of '%s', error is %u", dirName, GetLastError())); + + while (true) { + String entry = utf16ToString(findFileData.cFileName); + if (!skipDots || (entry != "." && entry != "..")) + fileList.append({entry, (FILE_ATTRIBUTE_DIRECTORY & findFileData.dwFileAttributes) != 0}); + if (!FindNextFileW(hFind, &findFileData)) + break; + } + + DWORD dwError = GetLastError(); + FindClose(hFind); + + if ((dwError != ERROR_NO_MORE_FILES) && (dwError != NO_ERROR)) + throw IOException(strf("FindNextFile error in dirList of '%s'. Error is %u", dirName, dwError)); + + return fileList; +} + +String File::baseName(const String& fileName) { + return String(fileName).rextract("\\/"); +} + +String File::dirName(const String& fileName) { + if (fileName == "\\" || fileName == "/") + return "\\"; + + String directory = fileName; + directory.rextract("\\/"); + if (directory.empty()) + return "."; + else + return directory; +} + +String File::relativeTo(String const& relativeTo, String const& path) { + if (path.beginsWith('/') || path.beginsWith('\\') || path.regexMatch("^[a-z]:", false, false)) + return path; + + String finalPath; + if (relativeTo.endsWith('\\') || relativeTo.endsWith('/')) + finalPath = relativeTo.substr(0, relativeTo.size() - 1); + else if (relativeTo.endsWith("\\.") || relativeTo.endsWith("/.")) + finalPath = relativeTo.substr(0, relativeTo.size() - 2); + else + finalPath = relativeTo; + + if (path.beginsWith(".\\") || path.beginsWith("./")) + finalPath += '\\' + path.substr(2); + else + finalPath += '\\' + path; + + return finalPath; +} + +String File::temporaryFileName() { + WCHAR tempPath[MAX_PATH]; + if (!GetTempPathW(MAX_PATH, tempPath)) { + auto error = GetLastError(); + throw IOException(strf("Could not call GetTempPath %s", error)); + } + + return relativeTo(utf16ToString(tempPath), strf("starbound.tmpfile.%s", hexEncode(Random::randBytes(16)))); +} + +FilePtr File::temporaryFile() { + return open(temporaryFileName(), IOMode::ReadWrite); +} + +FilePtr File::ephemeralFile() { + auto file = temporaryFile(); + DeleteFileW(stringToUtf16(file->fileName()).get()); + file->m_filename = ""; + return file; +} + +String File::temporaryDirectory() { + WCHAR tempPath[MAX_PATH]; + if (!GetTempPathW(MAX_PATH, tempPath)) { + auto error = GetLastError(); + throw IOException(strf("Could not call GetTempPath %s", error)); + } + + String dirname = relativeTo(utf16ToString(tempPath), strf("starbound.tmpdir.%s", hexEncode(Random::randBytes(16)))); + makeDirectory(dirname); + return dirname; +} + +void File::remove(String const& filename) { + if (isDirectory(filename)) { + if (!RemoveDirectoryW(stringToUtf16(filename).get())) { + auto error = GetLastError(); + throw IOException(strf("Rename error: %s", error)); + } + } else if (::_wremove(stringToUtf16(filename).get()) < 0) { + auto error = errno; + throw IOException::format("remove error: %s", strerror(error)); + } +} + +void File::rename(String const& source, String const& target) { + bool replace = File::exists(target); + auto temp = target + ".tmp"; + + if (replace) { + if (!DeleteFileW(stringToUtf16(temp).get())) { + auto error = GetLastError(); + if (error != ERROR_FILE_NOT_FOUND) + throw IOException(strf("error deleting existing temp file: %s", error)); + } + if (!MoveFileExW(stringToUtf16(target).get(), stringToUtf16(temp).get(), MOVEFILE_COPY_ALLOWED | MOVEFILE_WRITE_THROUGH)) { + auto error = GetLastError(); + throw IOException(strf("error temporary file '%s': %s", temp, GetLastError())); + } + } + + if (!MoveFileExW(stringToUtf16(source).get(), stringToUtf16(target).get(), MOVEFILE_REPLACE_EXISTING | MOVEFILE_COPY_ALLOWED | MOVEFILE_WRITE_THROUGH)) { + auto error = GetLastError(); + throw IOException(strf("Rename error: %s", error)); + } + + if (replace && !DeleteFileW(stringToUtf16(temp).get())) { + auto error = GetLastError(); + throw IOException(strf("error deleting temp file '%s': %s", temp, GetLastError())); + } +} + +void File::overwriteFileWithRename(char const* data, size_t len, String const& filename, String const& newSuffix) { + String newFile = filename + newSuffix; + + try { + auto file = File::open(newFile, IOMode::Write | IOMode::Truncate); + file->writeFull(data, len); + file->sync(); + file->close(); + + File::rename(newFile, filename); + } catch (IOException const&) { + // HACK: Been having trouble on windows with the write / flush / move + // sequence, try super hard to just write the file non-atomically in case + // of weird file locking problems instead of erroring. + + // Ignore any error on removal of the maybe existing newFile + ::_wremove(stringToUtf16(newFile).get()); + + writeFile(data, len, filename); + } +} + +void* File::fopen(char const* filename, IOMode mode) { + DWORD desiredAccess = 0; + if (mode & IOMode::Read) + desiredAccess |= GENERIC_READ; + if (mode & IOMode::Write) + desiredAccess |= GENERIC_WRITE; + + DWORD creationDisposition = 0; + if (mode & IOMode::Write) + creationDisposition = OPEN_ALWAYS; + else + creationDisposition = OPEN_EXISTING; + + HANDLE file = CreateFileW(stringToUtf16(String(filename)).get(), + desiredAccess, FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE, NULL, + creationDisposition, 0, NULL); + + if (file == INVALID_HANDLE_VALUE) + throw IOException::format("could not open file '%s' %s", filename, GetLastError()); + + LARGE_INTEGER szero; + szero.QuadPart = 0; + if (!SetFilePointerEx(file, szero, NULL, 0)) { + CloseHandle(file); + throw IOException::format("could not set file pointer in fopen '%s' %s", filename, GetLastError()); + } + + if (mode & IOMode::Truncate) { + if (!SetEndOfFile(file)) { + CloseHandle(file); + throw IOException::format("could not set end of file in fopen '%s' %s", filename, GetLastError()); + } + } + + if (mode & IOMode::Append) { + LARGE_INTEGER size; + if (GetFileSizeEx(file, &size) == 0) { + CloseHandle(file); + throw IOException::format("could not get file size in fopen '%s' %s", filename, GetLastError()); + } + if (!SetFilePointerEx(file, size, NULL, 0)) { + CloseHandle(file); + throw IOException::format("could not set file pointer in fopen '%s' %s", filename, GetLastError()); + } + } + + return (void*)file; +} + +void File::fseek(void* f, StreamOffset offset, IOSeek seekMode) { + HANDLE file = (HANDLE)f; + + LARGE_INTEGER loffset; + loffset.QuadPart = offset; + + if (seekMode == IOSeek::Relative) + SetFilePointerEx(file, loffset, nullptr, FILE_CURRENT); + else if (seekMode == IOSeek::Absolute) + SetFilePointerEx(file, loffset, nullptr, FILE_BEGIN); + else + SetFilePointerEx(file, loffset, nullptr, FILE_END); +} + +StreamOffset File::ftell(void* f) { + HANDLE file = (HANDLE)f; + LARGE_INTEGER pos; + LARGE_INTEGER szero; + szero.QuadPart = 0; + SetFilePointerEx(file, szero, &pos, FILE_CURRENT); + return pos.QuadPart; +} + +size_t File::fread(void* f, char* data, size_t len) { + if (len == 0) + return 0; + + HANDLE file = (HANDLE)f; + + DWORD numRead = 0; + int ret = ReadFile(file, data, len, &numRead, nullptr); + if (ret == 0) { + auto err = GetLastError(); + if (err != ERROR_IO_PENDING) + throw IOException::format("read error %s", err); + } + + return numRead; +} + +size_t File::fwrite(void* f, char const* data, size_t len) { + if (len == 0) + return 0; + + HANDLE file = (HANDLE)f; + + DWORD numWritten = 0; + int ret = WriteFile(file, data, len, &numWritten, nullptr); + if (ret == 0) { + auto err = GetLastError(); + if (err != ERROR_IO_PENDING) + throw IOException::format("write error %s", err); + } + + return numWritten; +} + +void File::fsync(void* f) { + HANDLE file = (HANDLE)f; + if (FlushFileBuffers(file) == 0) + throw IOException::format("fsync error %s", GetLastError()); +} + +void File::fclose(void* f) { + HANDLE file = (HANDLE)f; + CloseHandle(file); +} + +StreamOffset File::fsize(void* f) { + HANDLE file = (HANDLE)f; + LARGE_INTEGER size; + if (GetFileSizeEx(file, &size) == 0) + throw IOException::format("could not get file size in fsize %s", GetLastError()); + return size.QuadPart; +} + +size_t File::pread(void* f, char* data, size_t len, StreamOffset position) { + HANDLE file = (HANDLE)f; + DWORD numRead = 0; + OVERLAPPED overlapped = makeOverlapped(position); + int ret = ReadFile(file, data, len, &numRead, &overlapped); + if (ret == 0) { + auto err = GetLastError(); + if (err != ERROR_IO_PENDING) + throw IOException::format("pread error %s", err); + } + + return numRead; +} + +size_t File::pwrite(void* f, char const* data, size_t len, StreamOffset position) { + HANDLE file = (HANDLE)f; + DWORD numWritten = 0; + OVERLAPPED overlapped = makeOverlapped(position); + int ret = WriteFile(file, data, len, &numWritten, &overlapped); + if (ret == 0) { + auto err = GetLastError(); + if (err != ERROR_IO_PENDING) + throw IOException::format("pwrite error %s", err); + } + + return numWritten; +} + +void File::resize(void* f, StreamOffset size) { + HANDLE file = (HANDLE)f; + LARGE_INTEGER s; + s.QuadPart = size; + SetFilePointerEx(file, s, NULL, 0); + SetEndOfFile(file); +} + +} diff --git a/source/core/StarFlatHashMap.hpp b/source/core/StarFlatHashMap.hpp new file mode 100644 index 0000000..6a8c739 --- /dev/null +++ b/source/core/StarFlatHashMap.hpp @@ -0,0 +1,545 @@ +#ifndef STAR_FLAT_HASH_MAP_HPP +#define STAR_FLAT_HASH_MAP_HPP + +#include <type_traits> + +#include "StarFlatHashTable.hpp" +#include "StarHash.hpp" + +namespace Star { + +template <typename Key, typename Mapped, typename Hash = hash<Key>, typename Equals = std::equal_to<Key>, typename Allocator = std::allocator<Key>> +class FlatHashMap { +public: + typedef Key key_type; + typedef Mapped mapped_type; + typedef pair<key_type const, mapped_type> value_type; + typedef size_t size_type; + typedef ptrdiff_t difference_type; + typedef Hash hasher; + typedef Equals key_equal; + typedef Allocator allocator_type; + typedef value_type& reference; + typedef value_type const& const_reference; + typedef value_type* pointer; + typedef value_type const* const_pointer; + +private: + typedef pair<key_type, mapped_type> TableValue; + + struct GetKey { + key_type const& operator()(TableValue const& value) const; + }; + + typedef FlatHashTable<TableValue, key_type, GetKey, Hash, Equals, typename Allocator::template rebind<TableValue>::other> Table; + +public: + struct const_iterator { + typedef std::forward_iterator_tag iterator_category; + typedef typename FlatHashMap::value_type const value_type; + typedef ptrdiff_t difference_type; + typedef value_type* pointer; + typedef value_type& reference; + + bool operator==(const_iterator const& rhs) const; + bool operator!=(const_iterator const& rhs) const; + + const_iterator& operator++(); + const_iterator operator++(int); + + value_type& operator*() const; + value_type* operator->() const; + + typename Table::const_iterator inner; + }; + + struct iterator { + typedef std::forward_iterator_tag iterator_category; + typedef typename FlatHashMap::value_type value_type; + typedef ptrdiff_t difference_type; + typedef value_type* pointer; + typedef value_type& reference; + + bool operator==(iterator const& rhs) const; + bool operator!=(iterator const& rhs) const; + + iterator& operator++(); + iterator operator++(int); + + value_type& operator*() const; + value_type* operator->() const; + + operator const_iterator() const; + + typename Table::iterator inner; + }; + + FlatHashMap(); + explicit FlatHashMap(size_t bucketCount, hasher const& hash = hasher(), + key_equal const& equal = key_equal(), allocator_type const& alloc = allocator_type()); + FlatHashMap(size_t bucketCount, allocator_type const& alloc); + FlatHashMap(size_t bucketCount, hasher const& hash, allocator_type const& alloc); + explicit FlatHashMap(allocator_type const& alloc); + + template <typename InputIt> + FlatHashMap(InputIt first, InputIt last, size_t bucketCount = 0, + hasher const& hash = hasher(), key_equal const& equal = key_equal(), + allocator_type const& alloc = allocator_type()); + template <typename InputIt> + FlatHashMap(InputIt first, InputIt last, size_t bucketCount, allocator_type const& alloc); + template <typename InputIt> + FlatHashMap(InputIt first, InputIt last, size_t bucketCount, + hasher const& hash, allocator_type const& alloc); + + FlatHashMap(FlatHashMap const& other); + FlatHashMap(FlatHashMap const& other, allocator_type const& alloc); + FlatHashMap(FlatHashMap&& other); + FlatHashMap(FlatHashMap&& other, allocator_type const& alloc); + + FlatHashMap(initializer_list<value_type> init, size_t bucketCount = 0, + hasher const& hash = hasher(), key_equal const& equal = key_equal(), + allocator_type const& alloc = allocator_type()); + FlatHashMap(initializer_list<value_type> init, size_t bucketCount, allocator_type const& alloc); + FlatHashMap(initializer_list<value_type> init, size_t bucketCount, hasher const& hash, + allocator_type const& alloc); + + FlatHashMap& operator=(FlatHashMap const& other); + FlatHashMap& operator=(FlatHashMap&& other); + FlatHashMap& operator=(initializer_list<value_type> init); + + iterator begin(); + iterator end(); + + const_iterator begin() const; + const_iterator end() const; + + const_iterator cbegin() const; + const_iterator cend() const; + + size_t empty() const; + size_t size() const; + void clear(); + + pair<iterator, bool> insert(value_type const& value); + template <typename T, typename = typename std::enable_if<std::is_constructible<TableValue, T&&>::value>::type> + pair<iterator, bool> insert(T&& value); + iterator insert(const_iterator hint, value_type const& value); + template <typename T, typename = typename std::enable_if<std::is_constructible<TableValue, T&&>::value>::type> + iterator insert(const_iterator hint, T&& value); + template <typename InputIt> + void insert(InputIt first, InputIt last); + void insert(initializer_list<value_type> init); + + template <typename... Args> + pair<iterator, bool> emplace(Args&&... args); + template <typename... Args> + iterator emplace_hint(const_iterator hint, Args&&... args); + + iterator erase(const_iterator pos); + iterator erase(const_iterator first, const_iterator last); + size_t erase(key_type const& key); + + mapped_type& at(key_type const& key); + mapped_type const& at(key_type const& key) const; + + mapped_type& operator[](key_type const& key); + mapped_type& operator[](key_type&& key); + + size_t count(key_type const& key) const; + const_iterator find(key_type const& key) const; + iterator find(key_type const& key); + pair<iterator, iterator> equal_range(key_type const& key); + pair<const_iterator, const_iterator> equal_range(key_type const& key) const; + + void reserve(size_t capacity); + + bool operator==(FlatHashMap const& rhs) const; + bool operator!=(FlatHashMap const& rhs) const; + +private: + Table m_table; +}; + +template <typename Key, typename Mapped, typename Hash, typename Equals, typename Allocator> +auto FlatHashMap<Key, Mapped, Hash, Equals, Allocator>::GetKey::operator()(TableValue const& value) const -> key_type const& { + return value.first; +} + +template <typename Key, typename Mapped, typename Hash, typename Equals, typename Allocator> +bool FlatHashMap<Key, Mapped, Hash, Equals, Allocator>::const_iterator::operator==(const_iterator const& rhs) const { + return inner == rhs.inner; +} + +template <typename Key, typename Mapped, typename Hash, typename Equals, typename Allocator> +bool FlatHashMap<Key, Mapped, Hash, Equals, Allocator>::const_iterator::operator!=(const_iterator const& rhs) const { + return inner != rhs.inner; +} + +template <typename Key, typename Mapped, typename Hash, typename Equals, typename Allocator> +auto FlatHashMap<Key, Mapped, Hash, Equals, Allocator>::const_iterator::operator++() -> const_iterator& { + ++inner; + return *this; +} + +template <typename Key, typename Mapped, typename Hash, typename Equals, typename Allocator> +auto FlatHashMap<Key, Mapped, Hash, Equals, Allocator>::const_iterator::operator++(int) -> const_iterator { + const_iterator copy(*this); + ++*this; + return copy; +} + +template <typename Key, typename Mapped, typename Hash, typename Equals, typename Allocator> +auto FlatHashMap<Key, Mapped, Hash, Equals, Allocator>::const_iterator::operator*() const -> value_type& { + return *operator->(); +} + +template <typename Key, typename Mapped, typename Hash, typename Equals, typename Allocator> +auto FlatHashMap<Key, Mapped, Hash, Equals, Allocator>::const_iterator::operator->() const -> value_type* { + return (value_type*)(&*inner); +} + +template <typename Key, typename Mapped, typename Hash, typename Equals, typename Allocator> +bool FlatHashMap<Key, Mapped, Hash, Equals, Allocator>::iterator::operator==(iterator const& rhs) const { + return inner == rhs.inner; +} + +template <typename Key, typename Mapped, typename Hash, typename Equals, typename Allocator> +bool FlatHashMap<Key, Mapped, Hash, Equals, Allocator>::iterator::operator!=(iterator const& rhs) const { + return inner != rhs.inner; +} + +template <typename Key, typename Mapped, typename Hash, typename Equals, typename Allocator> +auto FlatHashMap<Key, Mapped, Hash, Equals, Allocator>::iterator::operator++() -> iterator& { + ++inner; + return *this; +} + +template <typename Key, typename Mapped, typename Hash, typename Equals, typename Allocator> +auto FlatHashMap<Key, Mapped, Hash, Equals, Allocator>::iterator::operator++(int) -> iterator { + iterator copy(*this); + operator++(); + return copy; +} + +template <typename Key, typename Mapped, typename Hash, typename Equals, typename Allocator> +auto FlatHashMap<Key, Mapped, Hash, Equals, Allocator>::iterator::operator*() const -> value_type& { + return *operator->(); +} + +template <typename Key, typename Mapped, typename Hash, typename Equals, typename Allocator> +auto FlatHashMap<Key, Mapped, Hash, Equals, Allocator>::iterator::operator->() const -> value_type* { + return (value_type*)(&*inner); +} + +template <typename Key, typename Mapped, typename Hash, typename Equals, typename Allocator> +FlatHashMap<Key, Mapped, Hash, Equals, Allocator>::iterator::operator typename FlatHashMap<Key, Mapped, Hash, Equals, Allocator>::const_iterator() const { + return const_iterator{inner}; +} + +template <typename Key, typename Mapped, typename Hash, typename Equals, typename Allocator> +FlatHashMap<Key, Mapped, Hash, Equals, Allocator>::FlatHashMap() + : FlatHashMap(0) {} + +template <typename Key, typename Mapped, typename Hash, typename Equals, typename Allocator> +FlatHashMap<Key, Mapped, Hash, Equals, Allocator>::FlatHashMap(size_t bucketCount, hasher const& hash, + key_equal const& equal, allocator_type const& alloc) + : m_table(bucketCount, GetKey(), hash, equal, alloc) {} + +template <typename Key, typename Mapped, typename Hash, typename Equals, typename Allocator> +FlatHashMap<Key, Mapped, Hash, Equals, Allocator>::FlatHashMap(size_t bucketCount, allocator_type const& alloc) + : FlatHashMap(bucketCount, hasher(), key_equal(), alloc) {} + +template <typename Key, typename Mapped, typename Hash, typename Equals, typename Allocator> +FlatHashMap<Key, Mapped, Hash, Equals, Allocator>::FlatHashMap(size_t bucketCount, hasher const& hash, + allocator_type const& alloc) + : FlatHashMap(bucketCount, hash, key_equal(), alloc) {} + +template <typename Key, typename Mapped, typename Hash, typename Equals, typename Allocator> +FlatHashMap<Key, Mapped, Hash, Equals, Allocator>::FlatHashMap(allocator_type const& alloc) + : FlatHashMap(0, hasher(), key_equal(), alloc) {} + +template <typename Key, typename Mapped, typename Hash, typename Equals, typename Allocator> +template <typename InputIt> +FlatHashMap<Key, Mapped, Hash, Equals, Allocator>::FlatHashMap(InputIt first, InputIt last, size_t bucketCount, + hasher const& hash, key_equal const& equal, allocator_type const& alloc) + : FlatHashMap(bucketCount, hash, equal, alloc) { + insert(first, last); +} + +template <typename Key, typename Mapped, typename Hash, typename Equals, typename Allocator> +template <typename InputIt> +FlatHashMap<Key, Mapped, Hash, Equals, Allocator>::FlatHashMap(InputIt first, InputIt last, size_t bucketCount, + allocator_type const& alloc) + : FlatHashMap(first, last, bucketCount, hasher(), key_equal(), alloc) {} + +template <typename Key, typename Mapped, typename Hash, typename Equals, typename Allocator> +template <typename InputIt> +FlatHashMap<Key, Mapped, Hash, Equals, Allocator>::FlatHashMap(InputIt first, InputIt last, size_t bucketCount, + hasher const& hash, allocator_type const& alloc) + : FlatHashMap(first, last, bucketCount, hash, key_equal(), alloc) {} + +template <typename Key, typename Mapped, typename Hash, typename Equals, typename Allocator> +FlatHashMap<Key, Mapped, Hash, Equals, Allocator>::FlatHashMap(FlatHashMap const& other) + : FlatHashMap(other, other.m_table.getAllocator()) {} + +template <typename Key, typename Mapped, typename Hash, typename Equals, typename Allocator> +FlatHashMap<Key, Mapped, Hash, Equals, Allocator>::FlatHashMap(FlatHashMap const& other, allocator_type const& alloc) + : FlatHashMap(alloc) { + operator=(other); +} + +template <typename Key, typename Mapped, typename Hash, typename Equals, typename Allocator> +FlatHashMap<Key, Mapped, Hash, Equals, Allocator>::FlatHashMap(FlatHashMap&& other) + : FlatHashMap(move(other), other.m_table.getAllocator()) {} + +template <typename Key, typename Mapped, typename Hash, typename Equals, typename Allocator> +FlatHashMap<Key, Mapped, Hash, Equals, Allocator>::FlatHashMap(FlatHashMap&& other, allocator_type const& alloc) + : FlatHashMap(alloc) { + operator=(move(other)); +} + +template <typename Key, typename Mapped, typename Hash, typename Equals, typename Allocator> +FlatHashMap<Key, Mapped, Hash, Equals, Allocator>::FlatHashMap(initializer_list<value_type> init, size_t bucketCount, hasher const& hash, + key_equal const& equal, allocator_type const& alloc) + : FlatHashMap(bucketCount, hash, equal, alloc) { + operator=(init); +} + +template <typename Key, typename Mapped, typename Hash, typename Equals, typename Allocator> +FlatHashMap<Key, Mapped, Hash, Equals, Allocator>::FlatHashMap(initializer_list<value_type> init, size_t bucketCount, + allocator_type const& alloc) + : FlatHashMap(init, bucketCount, hasher(), key_equal(), alloc) {} + +template <typename Key, typename Mapped, typename Hash, typename Equals, typename Allocator> +FlatHashMap<Key, Mapped, Hash, Equals, Allocator>::FlatHashMap(initializer_list<value_type> init, size_t bucketCount, hasher const& hash, + allocator_type const& alloc) + : FlatHashMap(init, bucketCount, hash, key_equal(), alloc) {} + +template <typename Key, typename Mapped, typename Hash, typename Equals, typename Allocator> +auto FlatHashMap<Key, Mapped, Hash, Equals, Allocator>::operator=(FlatHashMap const& other) -> FlatHashMap& { + m_table.clear(); + m_table.reserve(other.size()); + for (auto const& p : other) + m_table.insert(p); + return *this; +} + +template <typename Key, typename Mapped, typename Hash, typename Equals, typename Allocator> +auto FlatHashMap<Key, Mapped, Hash, Equals, Allocator>::operator=(FlatHashMap&& other) -> FlatHashMap& { + m_table = move(other.m_table); + return *this; +} + +template <typename Key, typename Mapped, typename Hash, typename Equals, typename Allocator> +auto FlatHashMap<Key, Mapped, Hash, Equals, Allocator>::operator=(initializer_list<value_type> init) -> FlatHashMap& { + clear(); + insert(init.begin(), init.end()); + return *this; +} + +template <typename Key, typename Mapped, typename Hash, typename Equals, typename Allocator> +auto FlatHashMap<Key, Mapped, Hash, Equals, Allocator>::begin() -> iterator { + return iterator{m_table.begin()}; +} + +template <typename Key, typename Mapped, typename Hash, typename Equals, typename Allocator> +auto FlatHashMap<Key, Mapped, Hash, Equals, Allocator>::end() -> iterator { + return iterator{m_table.end()}; +} + +template <typename Key, typename Mapped, typename Hash, typename Equals, typename Allocator> +auto FlatHashMap<Key, Mapped, Hash, Equals, Allocator>::begin() const -> const_iterator { + return const_iterator{m_table.begin()}; +} + +template <typename Key, typename Mapped, typename Hash, typename Equals, typename Allocator> +auto FlatHashMap<Key, Mapped, Hash, Equals, Allocator>::end() const -> const_iterator { + return const_iterator{m_table.end()}; +} + +template <typename Key, typename Mapped, typename Hash, typename Equals, typename Allocator> +auto FlatHashMap<Key, Mapped, Hash, Equals, Allocator>::cbegin() const -> const_iterator { + return begin(); +} + +template <typename Key, typename Mapped, typename Hash, typename Equals, typename Allocator> +auto FlatHashMap<Key, Mapped, Hash, Equals, Allocator>::cend() const -> const_iterator { + return end(); +} + +template <typename Key, typename Mapped, typename Hash, typename Equals, typename Allocator> +size_t FlatHashMap<Key, Mapped, Hash, Equals, Allocator>::empty() const { + return m_table.empty(); +} + +template <typename Key, typename Mapped, typename Hash, typename Equals, typename Allocator> +size_t FlatHashMap<Key, Mapped, Hash, Equals, Allocator>::size() const { + return m_table.size(); +} + +template <typename Key, typename Mapped, typename Hash, typename Equals, typename Allocator> +void FlatHashMap<Key, Mapped, Hash, Equals, Allocator>::clear() { + m_table.clear(); +} + +template <typename Key, typename Mapped, typename Hash, typename Equals, typename Allocator> +auto FlatHashMap<Key, Mapped, Hash, Equals, Allocator>::insert(value_type const& value) -> pair<iterator, bool> { + auto res = m_table.insert(TableValue(value)); + return {iterator{res.first}, res.second}; +} + +template <typename Key, typename Mapped, typename Hash, typename Equals, typename Allocator> +template <typename T, typename> +auto FlatHashMap<Key, Mapped, Hash, Equals, Allocator>::insert(T&& value) -> pair<iterator, bool> { + auto res = m_table.insert(TableValue(forward<T&&>(value))); + return {iterator{res.first}, res.second}; +} + +template <typename Key, typename Mapped, typename Hash, typename Equals, typename Allocator> +auto FlatHashMap<Key, Mapped, Hash, Equals, Allocator>::insert(const_iterator hint, value_type const& value) -> iterator { + return insert(hint, TableValue(value)); +} + +template <typename Key, typename Mapped, typename Hash, typename Equals, typename Allocator> +template <typename T, typename> +auto FlatHashMap<Key, Mapped, Hash, Equals, Allocator>::insert(const_iterator, T&& value) -> iterator { + return insert(forward<T&&>(value)).first; +} + +template <typename Key, typename Mapped, typename Hash, typename Equals, typename Allocator> +template <typename InputIt> +void FlatHashMap<Key, Mapped, Hash, Equals, Allocator>::insert(InputIt first, InputIt last) { + m_table.reserve(m_table.size() + std::distance(first, last)); + for (auto i = first; i != last; ++i) + m_table.insert(*i); +} + +template <typename Key, typename Mapped, typename Hash, typename Equals, typename Allocator> +void FlatHashMap<Key, Mapped, Hash, Equals, Allocator>::insert(initializer_list<value_type> init) { + insert(init.begin(), init.end()); +} + +template <typename Key, typename Mapped, typename Hash, typename Equals, typename Allocator> +template <typename... Args> +auto FlatHashMap<Key, Mapped, Hash, Equals, Allocator>::emplace(Args&&... args) -> pair<iterator, bool> { + return insert(TableValue(forward<Args>(args)...)); +} + +template <typename Key, typename Mapped, typename Hash, typename Equals, typename Allocator> +template <typename... Args> +auto FlatHashMap<Key, Mapped, Hash, Equals, Allocator>::emplace_hint(const_iterator hint, Args&&... args) -> iterator { + return insert(hint, TableValue(forward<Args>(args)...)); +} + +template <typename Key, typename Mapped, typename Hash, typename Equals, typename Allocator> +auto FlatHashMap<Key, Mapped, Hash, Equals, Allocator>::erase(const_iterator pos) -> iterator { + return iterator{m_table.erase(pos.inner)}; +} + +template <typename Key, typename Mapped, typename Hash, typename Equals, typename Allocator> +auto FlatHashMap<Key, Mapped, Hash, Equals, Allocator>::erase(const_iterator first, const_iterator last) -> iterator { + return iterator{m_table.erase(first.inner, last.inner)}; +} + +template <typename Key, typename Mapped, typename Hash, typename Equals, typename Allocator> +size_t FlatHashMap<Key, Mapped, Hash, Equals, Allocator>::erase(key_type const& key) { + auto i = m_table.find(key); + if (i != m_table.end()) { + m_table.erase(i); + return 1; + } + return 0; +} + +template <typename Key, typename Mapped, typename Hash, typename Equals, typename Allocator> +auto FlatHashMap<Key, Mapped, Hash, Equals, Allocator>::at(key_type const& key) -> mapped_type& { + auto i = m_table.find(key); + if (i == m_table.end()) + throw std::out_of_range("no such key in FlatHashMap"); + return i->second; +} + +template <typename Key, typename Mapped, typename Hash, typename Equals, typename Allocator> +auto FlatHashMap<Key, Mapped, Hash, Equals, Allocator>::at(key_type const& key) const -> mapped_type const& { + auto i = m_table.find(key); + if (i == m_table.end()) + throw std::out_of_range("no such key in FlatHashMap"); + return i->second; +} + +template <typename Key, typename Mapped, typename Hash, typename Equals, typename Allocator> +auto FlatHashMap<Key, Mapped, Hash, Equals, Allocator>::operator[](key_type const& key) -> mapped_type& { + auto i = m_table.find(key); + if (i != m_table.end()) + return i->second; + return m_table.insert({key, mapped_type()}).first->second; +} + +template <typename Key, typename Mapped, typename Hash, typename Equals, typename Allocator> +auto FlatHashMap<Key, Mapped, Hash, Equals, Allocator>::operator[](key_type&& key) -> mapped_type& { + auto i = m_table.find(key); + if (i != m_table.end()) + return i->second; + return m_table.insert({move(key), mapped_type()}).first->second; +} + +template <typename Key, typename Mapped, typename Hash, typename Equals, typename Allocator> +size_t FlatHashMap<Key, Mapped, Hash, Equals, Allocator>::count(key_type const& key) const { + if (m_table.find(key) != m_table.end()) + return 1; + else + return 0; +} + +template <typename Key, typename Mapped, typename Hash, typename Equals, typename Allocator> +auto FlatHashMap<Key, Mapped, Hash, Equals, Allocator>::find(key_type const& key) const -> const_iterator { + return const_iterator{m_table.find(key)}; +} + +template <typename Key, typename Mapped, typename Hash, typename Equals, typename Allocator> +auto FlatHashMap<Key, Mapped, Hash, Equals, Allocator>::find(key_type const& key) -> iterator { + return iterator{m_table.find(key)}; +} + +template <typename Key, typename Mapped, typename Hash, typename Equals, typename Allocator> +auto FlatHashMap<Key, Mapped, Hash, Equals, Allocator>::equal_range(key_type const& key) -> pair<iterator, iterator> { + auto i = find(key); + if (i != end()) { + auto j = i; + ++j; + return {i, j}; + } else { + return {i, i}; + } +} + +template <typename Key, typename Mapped, typename Hash, typename Equals, typename Allocator> +auto FlatHashMap<Key, Mapped, Hash, Equals, Allocator>::equal_range(key_type const& key) const -> pair<const_iterator, const_iterator> { + auto i = find(key); + if (i != end()) { + auto j = i; + ++j; + return {i, j}; + } else { + return {i, i}; + } +} + +template <typename Key, typename Mapped, typename Hash, typename Equals, typename Allocator> +void FlatHashMap<Key, Mapped, Hash, Equals, Allocator>::reserve(size_t capacity) { + m_table.reserve(capacity); +} + +template <typename Key, typename Mapped, typename Hash, typename Equals, typename Allocator> +bool FlatHashMap<Key, Mapped, Hash, Equals, Allocator>::operator==(FlatHashMap const& rhs) const { + return m_table == rhs.m_table; +} + +template <typename Key, typename Mapped, typename Hash, typename Equals, typename Allocator> +bool FlatHashMap<Key, Mapped, Hash, Equals, Allocator>::operator!=(FlatHashMap const& rhs) const { + return m_table != rhs.m_table; +} + +} + +#endif diff --git a/source/core/StarFlatHashSet.hpp b/source/core/StarFlatHashSet.hpp new file mode 100644 index 0000000..5502a1f --- /dev/null +++ b/source/core/StarFlatHashSet.hpp @@ -0,0 +1,497 @@ +#ifndef STAR_FLAT_HASH_SET_HPP +#define STAR_FLAT_HASH_SET_HPP + +#include "StarFlatHashTable.hpp" +#include "StarHash.hpp" + +namespace Star { + +template <typename Key, typename Hash = hash<Key>, typename Equals = std::equal_to<Key>, typename Allocator = std::allocator<Key>> +class FlatHashSet { +public: + typedef Key key_type; + typedef Key value_type; + typedef size_t size_type; + typedef ptrdiff_t difference_type; + typedef Hash hasher; + typedef Equals key_equal; + typedef Allocator allocator_type; + typedef value_type& reference; + typedef value_type const& const_reference; + typedef value_type* pointer; + typedef value_type const* const_pointer; + +private: + struct GetKey { + key_type const& operator()(value_type const& value) const; + }; + + typedef FlatHashTable<Key, Key, GetKey, Hash, Equals, Allocator> Table; + +public: + struct const_iterator { + typedef std::forward_iterator_tag iterator_category; + typedef typename FlatHashSet::value_type const value_type; + typedef ptrdiff_t difference_type; + typedef value_type* pointer; + typedef value_type& reference; + + bool operator==(const_iterator const& rhs) const; + bool operator!=(const_iterator const& rhs) const; + + const_iterator& operator++(); + const_iterator operator++(int); + + value_type& operator*() const; + value_type* operator->() const; + + typename Table::const_iterator inner; + }; + + struct iterator { + typedef std::forward_iterator_tag iterator_category; + typedef typename FlatHashSet::value_type value_type; + typedef ptrdiff_t difference_type; + typedef value_type* pointer; + typedef value_type& reference; + + bool operator==(iterator const& rhs) const; + bool operator!=(iterator const& rhs) const; + + iterator& operator++(); + iterator operator++(int); + + value_type& operator*() const; + value_type* operator->() const; + + operator const_iterator() const; + + typename Table::iterator inner; + }; + + FlatHashSet(); + explicit FlatHashSet(size_t bucketCount, hasher const& hash = hasher(), + key_equal const& equal = key_equal(), allocator_type const& alloc = allocator_type()); + FlatHashSet(size_t bucketCount, allocator_type const& alloc); + FlatHashSet(size_t bucketCount, hasher const& hash, allocator_type const& alloc); + explicit FlatHashSet(allocator_type const& alloc); + + template <typename InputIt> + FlatHashSet(InputIt first, InputIt last, size_t bucketCount = 0, + hasher const& hash = hasher(), key_equal const& equal = key_equal(), + allocator_type const& alloc = allocator_type()); + template <typename InputIt> + FlatHashSet(InputIt first, InputIt last, size_t bucketCount, allocator_type const& alloc); + template <typename InputIt> + FlatHashSet(InputIt first, InputIt last, size_t bucketCount, + hasher const& hash, allocator_type const& alloc); + + FlatHashSet(FlatHashSet const& other); + FlatHashSet(FlatHashSet const& other, allocator_type const& alloc); + FlatHashSet(FlatHashSet&& other); + FlatHashSet(FlatHashSet&& other, allocator_type const& alloc); + + FlatHashSet(initializer_list<value_type> init, size_t bucketCount = 0, + hasher const& hash = hasher(), key_equal const& equal = key_equal(), + allocator_type const& alloc = allocator_type()); + FlatHashSet(initializer_list<value_type> init, size_t bucketCount, allocator_type const& alloc); + FlatHashSet(initializer_list<value_type> init, size_t bucketCount, hasher const& hash, + allocator_type const& alloc); + + FlatHashSet& operator=(FlatHashSet const& other); + FlatHashSet& operator=(FlatHashSet&& other); + FlatHashSet& operator=(initializer_list<value_type> init); + + iterator begin(); + iterator end(); + + const_iterator begin() const; + const_iterator end() const; + + const_iterator cbegin() const; + const_iterator cend() const; + + size_t empty() const; + size_t size() const; + void clear(); + + pair<iterator, bool> insert(value_type const& value); + pair<iterator, bool> insert(value_type&& value); + iterator insert(const_iterator hint, value_type const& value); + iterator insert(const_iterator hint, value_type&& value); + template <typename InputIt> + void insert(InputIt first, InputIt last); + void insert(initializer_list<value_type> init); + + template <typename... Args> + pair<iterator, bool> emplace(Args&&... args); + template <typename... Args> + iterator emplace_hint(const_iterator hint, Args&&... args); + + iterator erase(const_iterator pos); + iterator erase(const_iterator first, const_iterator last); + size_t erase(key_type const& key); + + size_t count(key_type const& key) const; + const_iterator find(key_type const& key) const; + iterator find(key_type const& key); + pair<iterator, iterator> equal_range(key_type const& key); + pair<const_iterator, const_iterator> equal_range(key_type const& key) const; + + void reserve(size_t capacity); + + bool operator==(FlatHashSet const& rhs) const; + bool operator!=(FlatHashSet const& rhs) const; + +private: + Table m_table; +}; + +template <typename Key, typename Hash, typename Equals, typename Allocator> +auto FlatHashSet<Key, Hash, Equals, Allocator>::GetKey::operator()(value_type const& value) const -> key_type const& { + return value; +} + +template <typename Key, typename Hash, typename Equals, typename Allocator> +bool FlatHashSet<Key, Hash, Equals, Allocator>::const_iterator::operator==(const_iterator const& rhs) const { + return inner == rhs.inner; +} + +template <typename Key, typename Hash, typename Equals, typename Allocator> +bool FlatHashSet<Key, Hash, Equals, Allocator>::const_iterator::operator!=(const_iterator const& rhs) const { + return inner != rhs.inner; +} + +template <typename Key, typename Hash, typename Equals, typename Allocator> +auto FlatHashSet<Key, Hash, Equals, Allocator>::const_iterator::operator++() -> const_iterator& { + ++inner; + return *this; +} + +template <typename Key, typename Hash, typename Equals, typename Allocator> +auto FlatHashSet<Key, Hash, Equals, Allocator>::const_iterator::operator++(int) -> const_iterator { + const_iterator copy(*this); + operator++(); + return copy; +} + +template <typename Key, typename Hash, typename Equals, typename Allocator> +auto FlatHashSet<Key, Hash, Equals, Allocator>::const_iterator::operator*() const -> value_type& { + return *inner; +} + +template <typename Key, typename Hash, typename Equals, typename Allocator> +auto FlatHashSet<Key, Hash, Equals, Allocator>::const_iterator::operator->() const -> value_type* { + return &operator*(); +} + +template <typename Key, typename Hash, typename Equals, typename Allocator> +bool FlatHashSet<Key, Hash, Equals, Allocator>::iterator::operator==(iterator const& rhs) const { + return inner == rhs.inner; +} + +template <typename Key, typename Hash, typename Equals, typename Allocator> +bool FlatHashSet<Key, Hash, Equals, Allocator>::iterator::operator!=(iterator const& rhs) const { + return inner != rhs.inner; +} + +template <typename Key, typename Hash, typename Equals, typename Allocator> +auto FlatHashSet<Key, Hash, Equals, Allocator>::iterator::operator++() -> iterator& { + ++inner; + return *this; +} + +template <typename Key, typename Hash, typename Equals, typename Allocator> +auto FlatHashSet<Key, Hash, Equals, Allocator>::iterator::operator++(int) -> iterator { + iterator copy(*this); + operator++(); + return copy; +} + +template <typename Key, typename Hash, typename Equals, typename Allocator> +auto FlatHashSet<Key, Hash, Equals, Allocator>::iterator::operator*() const -> value_type& { + return *inner; +} + +template <typename Key, typename Hash, typename Equals, typename Allocator> +auto FlatHashSet<Key, Hash, Equals, Allocator>::iterator::operator->() const -> value_type* { + return &operator*(); +} + +template <typename Key, typename Hash, typename Equals, typename Allocator> +FlatHashSet<Key, Hash, Equals, Allocator>::iterator::operator typename FlatHashSet<Key, Hash, Equals, Allocator>::const_iterator() const { + return const_iterator{inner}; +} + +template <typename Key, typename Hash, typename Equals, typename Allocator> +FlatHashSet<Key, Hash, Equals, Allocator>::FlatHashSet() + : FlatHashSet(0) {} + +template <typename Key, typename Hash, typename Equals, typename Allocator> +FlatHashSet<Key, Hash, Equals, Allocator>::FlatHashSet(size_t bucketCount, hasher const& hash, + key_equal const& equal, allocator_type const& alloc) + : m_table(bucketCount, GetKey(), hash, equal, alloc) {} + +template <typename Key, typename Hash, typename Equals, typename Allocator> +FlatHashSet<Key, Hash, Equals, Allocator>::FlatHashSet(size_t bucketCount, allocator_type const& alloc) + : FlatHashSet(bucketCount, hasher(), key_equal(), alloc) {} + +template <typename Key, typename Hash, typename Equals, typename Allocator> +FlatHashSet<Key, Hash, Equals, Allocator>::FlatHashSet(size_t bucketCount, hasher const& hash, + allocator_type const& alloc) + : FlatHashSet(bucketCount, hash, key_equal(), alloc) {} + +template <typename Key, typename Hash, typename Equals, typename Allocator> +FlatHashSet<Key, Hash, Equals, Allocator>::FlatHashSet(allocator_type const& alloc) + : FlatHashSet(0, hasher(), key_equal(), alloc) {} + +template <typename Key, typename Hash, typename Equals, typename Allocator> +template <typename InputIt> +FlatHashSet<Key, Hash, Equals, Allocator>::FlatHashSet(InputIt first, InputIt last, size_t bucketCount, + hasher const& hash, key_equal const& equal, allocator_type const& alloc) + : FlatHashSet(bucketCount, hash, equal, alloc) { + insert(first, last); +} + +template <typename Key, typename Hash, typename Equals, typename Allocator> +template <typename InputIt> +FlatHashSet<Key, Hash, Equals, Allocator>::FlatHashSet(InputIt first, InputIt last, size_t bucketCount, + allocator_type const& alloc) + : FlatHashSet(first, last, bucketCount, hasher(), key_equal(), alloc) {} + +template <typename Key, typename Hash, typename Equals, typename Allocator> +template <typename InputIt> +FlatHashSet<Key, Hash, Equals, Allocator>::FlatHashSet(InputIt first, InputIt last, size_t bucketCount, + hasher const& hash, allocator_type const& alloc) + : FlatHashSet(first, last, bucketCount, hash, key_equal(), alloc) {} + +template <typename Key, typename Hash, typename Equals, typename Allocator> +FlatHashSet<Key, Hash, Equals, Allocator>::FlatHashSet(FlatHashSet const& other) + : FlatHashSet(other, other.m_table.getAllocator()) {} + +template <typename Key, typename Hash, typename Equals, typename Allocator> +FlatHashSet<Key, Hash, Equals, Allocator>::FlatHashSet(FlatHashSet const& other, allocator_type const& alloc) + : FlatHashSet(alloc) { + operator=(other); +} + +template <typename Key, typename Hash, typename Equals, typename Allocator> +FlatHashSet<Key, Hash, Equals, Allocator>::FlatHashSet(FlatHashSet&& other) + : FlatHashSet(move(other), other.m_table.getAllocator()) {} + +template <typename Key, typename Hash, typename Equals, typename Allocator> +FlatHashSet<Key, Hash, Equals, Allocator>::FlatHashSet(FlatHashSet&& other, allocator_type const& alloc) + : FlatHashSet(alloc) { + operator=(move(other)); +} + +template <typename Key, typename Hash, typename Equals, typename Allocator> +FlatHashSet<Key, Hash, Equals, Allocator>::FlatHashSet(initializer_list<value_type> init, size_t bucketCount, + hasher const& hash, key_equal const& equal, allocator_type const& alloc) + : FlatHashSet(bucketCount, hash, equal, alloc) { + operator=(init); +} + +template <typename Key, typename Hash, typename Equals, typename Allocator> +FlatHashSet<Key, Hash, Equals, Allocator>::FlatHashSet(initializer_list<value_type> init, size_t bucketCount, allocator_type const& alloc) + : FlatHashSet(init, bucketCount, hasher(), key_equal(), alloc) {} + +template <typename Key, typename Hash, typename Equals, typename Allocator> +FlatHashSet<Key, Hash, Equals, Allocator>::FlatHashSet(initializer_list<value_type> init, size_t bucketCount, + hasher const& hash, allocator_type const& alloc) + : FlatHashSet(init, bucketCount, hash, key_equal(), alloc) {} + +template <typename Key, typename Hash, typename Equals, typename Allocator> +FlatHashSet<Key, Hash, Equals, Allocator>& FlatHashSet<Key, Hash, Equals, Allocator>::operator=(FlatHashSet const& other) { + m_table.clear(); + m_table.reserve(other.size()); + for (auto const& p : other) + m_table.insert(p); + return *this; +} + +template <typename Key, typename Hash, typename Equals, typename Allocator> +FlatHashSet<Key, Hash, Equals, Allocator>& FlatHashSet<Key, Hash, Equals, Allocator>::operator=(FlatHashSet&& other) { + m_table = move(other.m_table); + return *this; +} + +template <typename Key, typename Hash, typename Equals, typename Allocator> +FlatHashSet<Key, Hash, Equals, Allocator>& FlatHashSet<Key, Hash, Equals, Allocator>::operator=(initializer_list<value_type> init) { + clear(); + insert(init.begin(), init.end()); + return *this; +} + +template <typename Key, typename Hash, typename Equals, typename Allocator> +auto FlatHashSet<Key, Hash, Equals, Allocator>::begin() -> iterator { + return iterator{m_table.begin()}; +} + +template <typename Key, typename Hash, typename Equals, typename Allocator> +auto FlatHashSet<Key, Hash, Equals, Allocator>::end() -> iterator { + return iterator{m_table.end()}; +} + +template <typename Key, typename Hash, typename Equals, typename Allocator> +auto FlatHashSet<Key, Hash, Equals, Allocator>::begin() const -> const_iterator { + return const_iterator{m_table.begin()}; +} + +template <typename Key, typename Hash, typename Equals, typename Allocator> +auto FlatHashSet<Key, Hash, Equals, Allocator>::end() const -> const_iterator { + return const_iterator{m_table.end()}; +} + +template <typename Key, typename Hash, typename Equals, typename Allocator> +auto FlatHashSet<Key, Hash, Equals, Allocator>::cbegin() const -> const_iterator { + return begin(); +} + +template <typename Key, typename Hash, typename Equals, typename Allocator> +auto FlatHashSet<Key, Hash, Equals, Allocator>::cend() const -> const_iterator { + return end(); +} + +template <typename Key, typename Hash, typename Equals, typename Allocator> +size_t FlatHashSet<Key, Hash, Equals, Allocator>::empty() const { + return m_table.empty(); +} + +template <typename Key, typename Hash, typename Equals, typename Allocator> +size_t FlatHashSet<Key, Hash, Equals, Allocator>::size() const { + return m_table.size(); +} + +template <typename Key, typename Hash, typename Equals, typename Allocator> +void FlatHashSet<Key, Hash, Equals, Allocator>::clear() { + m_table.clear(); +} + +template <typename Key, typename Hash, typename Equals, typename Allocator> +auto FlatHashSet<Key, Hash, Equals, Allocator>::insert(value_type const& value) -> pair<iterator, bool> { + auto res = m_table.insert(value); + return {iterator{res.first}, res.second}; +} + +template <typename Key, typename Hash, typename Equals, typename Allocator> +auto FlatHashSet<Key, Hash, Equals, Allocator>::insert(value_type&& value) -> pair<iterator, bool> { + auto res = m_table.insert(move(value)); + return {iterator{res.first}, res.second}; +} + +template <typename Key, typename Hash, typename Equals, typename Allocator> +auto FlatHashSet<Key, Hash, Equals, Allocator>::insert(const_iterator i, value_type const& value) -> iterator { + return insert(i, value_type(value)); +} + +template <typename Key, typename Hash, typename Equals, typename Allocator> +auto FlatHashSet<Key, Hash, Equals, Allocator>::insert(const_iterator, value_type&& value) -> iterator { + return insert(move(value)).first; +} + +template <typename Key, typename Hash, typename Equals, typename Allocator> +template <typename InputIt> +void FlatHashSet<Key, Hash, Equals, Allocator>::insert(InputIt first, InputIt last) { + m_table.reserve(m_table.size() + std::distance(first, last)); + for (auto i = first; i != last; ++i) + m_table.insert(*i); +} + +template <typename Key, typename Hash, typename Equals, typename Allocator> +void FlatHashSet<Key, Hash, Equals, Allocator>::insert(initializer_list<value_type> init) { + insert(init.begin(), init.end()); +} + +template <typename Key, typename Hash, typename Equals, typename Allocator> +template <typename... Args> +auto FlatHashSet<Key, Hash, Equals, Allocator>::emplace(Args&&... args) -> pair<iterator, bool> { + return insert(value_type(forward<Args>(args)...)); +} + +template <typename Key, typename Hash, typename Equals, typename Allocator> +template <typename... Args> +auto FlatHashSet<Key, Hash, Equals, Allocator>::emplace_hint(const_iterator i, Args&&... args) -> iterator { + return insert(i, value_type(forward<Args>(args)...)); +} + +template <typename Key, typename Hash, typename Equals, typename Allocator> +auto FlatHashSet<Key, Hash, Equals, Allocator>::erase(const_iterator pos) -> iterator { + return iterator{m_table.erase(pos.inner)}; +} + +template <typename Key, typename Hash, typename Equals, typename Allocator> +auto FlatHashSet<Key, Hash, Equals, Allocator>::erase(const_iterator first, const_iterator last) -> iterator { + return iterator{m_table.erase(first.inner, last.inner)}; +} + +template <typename Key, typename Hash, typename Equals, typename Allocator> +size_t FlatHashSet<Key, Hash, Equals, Allocator>::erase(key_type const& key) { + auto i = m_table.find(key); + if (i != m_table.end()) { + m_table.erase(i); + return 1; + } + return 0; +} + +template <typename Key, typename Hash, typename Equals, typename Allocator> +size_t FlatHashSet<Key, Hash, Equals, Allocator>::count(Key const& key) const { + if (m_table.find(key) != m_table.end()) + return 1; + else + return 0; +} + +template <typename Key, typename Hash, typename Equals, typename Allocator> +auto FlatHashSet<Key, Hash, Equals, Allocator>::find(key_type const& key) const -> const_iterator { + return const_iterator{m_table.find(key)}; +} + +template <typename Key, typename Hash, typename Equals, typename Allocator> +auto FlatHashSet<Key, Hash, Equals, Allocator>::find(key_type const& key) -> iterator { + return iterator{m_table.find(key)}; +} + +template <typename Key, typename Hash, typename Equals, typename Allocator> +auto FlatHashSet<Key, Hash, Equals, Allocator>::equal_range(key_type const& key) -> pair<iterator, iterator> { + auto i = find(key); + if (i != end()) { + auto j = i; + ++j; + return {i, j}; + } else { + return {i, i}; + } +} + +template <typename Key, typename Hash, typename Equals, typename Allocator> +auto FlatHashSet<Key, Hash, Equals, Allocator>::equal_range(key_type const& key) const -> pair<const_iterator, const_iterator> { + auto i = find(key); + if (i != end()) { + auto j = i; + ++j; + return {i, j}; + } else { + return {i, i}; + } +} + +template <typename Key, typename Hash, typename Equals, typename Allocator> +void FlatHashSet<Key, Hash, Equals, Allocator>::reserve(size_t capacity) { + m_table.reserve(capacity); +} + +template <typename Key, typename Hash, typename Equals, typename Allocator> +bool FlatHashSet<Key, Hash, Equals, Allocator>::operator==(FlatHashSet const& rhs) const { + return m_table == rhs.m_table; +} + +template <typename Key, typename Hash, typename Equals, typename Allocator> +bool FlatHashSet<Key, Hash, Equals, Allocator>::operator!=(FlatHashSet const& rhs) const { + return m_table != rhs.m_table; +} + +} + +#endif diff --git a/source/core/StarFlatHashTable.hpp b/source/core/StarFlatHashTable.hpp new file mode 100644 index 0000000..a13e08a --- /dev/null +++ b/source/core/StarFlatHashTable.hpp @@ -0,0 +1,557 @@ +#ifndef STAR_FLAT_HASH_TABLE_HPP +#define STAR_FLAT_HASH_TABLE_HPP + +#include <vector> + +#include "StarConfig.hpp" + +namespace Star { + +template <typename Value, typename Key, typename GetKey, typename Hash, typename Equals, typename Allocator> +struct FlatHashTable { +private: + static size_t const EmptyHashValue = 0; + static size_t const EndHashValue = 1; + static size_t const FilledHashBit = (size_t)1 << (sizeof(size_t) * 8 - 1); + + struct Bucket { + Bucket(); + ~Bucket(); + + Bucket(Bucket const& rhs); + Bucket(Bucket&& rhs); + + Bucket& operator=(Bucket const& rhs); + Bucket& operator=(Bucket&& rhs); + + void setFilled(size_t hash, Value value); + void setEmpty(); + void setEnd(); + + Value const* valuePtr() const; + Value* valuePtr(); + bool isEmpty() const; + bool isEnd() const; + + union { + Value value; + }; + size_t hash; + }; + + typedef std::vector<Bucket, typename Allocator::template rebind<Bucket>::other> Buckets; + +public: + struct const_iterator { + bool operator==(const_iterator const& rhs) const; + bool operator!=(const_iterator const& rhs) const; + + const_iterator& operator++(); + const_iterator operator++(int); + + Value const& operator*() const; + Value const* operator->() const; + + Bucket const* current; + }; + + struct iterator { + bool operator==(iterator const& rhs) const; + bool operator!=(iterator const& rhs) const; + + iterator& operator++(); + iterator operator++(int); + + Value& operator*() const; + Value* operator->() const; + + operator const_iterator() const; + + Bucket* current; + }; + + FlatHashTable(size_t bucketCount, GetKey const& getKey, Hash const& hash, Equals const& equal, Allocator const& alloc); + + iterator begin(); + iterator end(); + + const_iterator begin() const; + const_iterator end() const; + + size_t empty() const; + size_t size() const; + void clear(); + + pair<iterator, bool> insert(Value value); + + iterator erase(const_iterator pos); + iterator erase(const_iterator first, const_iterator last); + + const_iterator find(Key const& key) const; + iterator find(Key const& key); + + void reserve(size_t capacity); + Allocator getAllocator() const; + + bool operator==(FlatHashTable const& rhs) const; + bool operator!=(FlatHashTable const& rhs) const; + +private: + static constexpr size_t MinCapacity = 8; + static constexpr double MaxFillLevel = 0.7; + + // Scans for the next bucket value that is non-empty + static Bucket* scan(Bucket* p); + static Bucket const* scan(Bucket const* p); + + size_t hashBucket(size_t hash) const; + size_t bucketError(size_t current, size_t target) const; + void checkCapacity(size_t additionalCapacity); + + Buckets m_buckets; + size_t m_filledCount; + + GetKey m_getKey; + Hash m_hash; + Equals m_equals; +}; + +template <typename Value, typename Key, typename GetKey, typename Hash, typename Equals, typename Allocator> +FlatHashTable<Value, Key, GetKey, Hash, Equals, Allocator>::Bucket::Bucket() { + this->hash = EmptyHashValue; +} + +template <typename Value, typename Key, typename GetKey, typename Hash, typename Equals, typename Allocator> +FlatHashTable<Value, Key, GetKey, Hash, Equals, Allocator>::Bucket::~Bucket() { + if (auto s = valuePtr()) + s->~Value(); +} + +template <typename Value, typename Key, typename GetKey, typename Hash, typename Equals, typename Allocator> +FlatHashTable<Value, Key, GetKey, Hash, Equals, Allocator>::Bucket::Bucket(Bucket const& rhs) { + this->hash = rhs.hash; + if (auto o = rhs.valuePtr()) + new (&this->value) Value(*o); +} + +template <typename Value, typename Key, typename GetKey, typename Hash, typename Equals, typename Allocator> +FlatHashTable<Value, Key, GetKey, Hash, Equals, Allocator>::Bucket::Bucket(Bucket&& rhs) { + this->hash = rhs.hash; + if (auto o = rhs.valuePtr()) + new (&this->value) Value(move(*o)); +} + +template <typename Value, typename Key, typename GetKey, typename Hash, typename Equals, typename Allocator> +auto FlatHashTable<Value, Key, GetKey, Hash, Equals, Allocator>::Bucket::operator=(Bucket const& rhs) -> Bucket& { + if (auto o = rhs.valuePtr()) { + if (auto s = valuePtr()) + *s = *o; + else + new (&this->value) Value(*o); + } else { + if (auto s = valuePtr()) + s->~Value(); + } + this->hash = rhs.hash; + return *this; +} + +template <typename Value, typename Key, typename GetKey, typename Hash, typename Equals, typename Allocator> +auto FlatHashTable<Value, Key, GetKey, Hash, Equals, Allocator>::Bucket::operator=(Bucket&& rhs) -> Bucket& { + if (auto o = rhs.valuePtr()) { + if (auto s = valuePtr()) + *s = move(*o); + else + new (&this->value) Value(move(*o)); + } else { + if (auto s = valuePtr()) + s->~Value(); + } + this->hash = rhs.hash; + return *this; +} + +template <typename Value, typename Key, typename GetKey, typename Hash, typename Equals, typename Allocator> +void FlatHashTable<Value, Key, GetKey, Hash, Equals, Allocator>::Bucket::setFilled(size_t hash, Value value) { + if (auto s = valuePtr()) + *s = move(value); + else + new (&this->value) Value(move(value)); + this->hash = hash | FilledHashBit; +} + +template <typename Value, typename Key, typename GetKey, typename Hash, typename Equals, typename Allocator> +void FlatHashTable<Value, Key, GetKey, Hash, Equals, Allocator>::Bucket::setEmpty() { + if (auto s = valuePtr()) + s->~Value(); + this->hash = EmptyHashValue; +} + +template <typename Value, typename Key, typename GetKey, typename Hash, typename Equals, typename Allocator> +void FlatHashTable<Value, Key, GetKey, Hash, Equals, Allocator>::Bucket::setEnd() { + if (auto s = valuePtr()) + s->~Value(); + this->hash = EndHashValue; +} + +template <typename Value, typename Key, typename GetKey, typename Hash, typename Equals, typename Allocator> +Value const* FlatHashTable<Value, Key, GetKey, Hash, Equals, Allocator>::Bucket::valuePtr() const { + if (hash & FilledHashBit) + return &this->value; + return nullptr; +} + +template <typename Value, typename Key, typename GetKey, typename Hash, typename Equals, typename Allocator> +Value* FlatHashTable<Value, Key, GetKey, Hash, Equals, Allocator>::Bucket::valuePtr() { + if (hash & FilledHashBit) + return &this->value; + return nullptr; +} + +template <typename Value, typename Key, typename GetKey, typename Hash, typename Equals, typename Allocator> +bool FlatHashTable<Value, Key, GetKey, Hash, Equals, Allocator>::Bucket::isEmpty() const { + return this->hash == EmptyHashValue; +} + +template <typename Value, typename Key, typename GetKey, typename Hash, typename Equals, typename Allocator> +bool FlatHashTable<Value, Key, GetKey, Hash, Equals, Allocator>::Bucket::isEnd() const { + return this->hash == EndHashValue; +} + +template <typename Value, typename Key, typename GetKey, typename Hash, typename Equals, typename Allocator> +bool FlatHashTable<Value, Key, GetKey, Hash, Equals, Allocator>::const_iterator::operator==(const_iterator const& rhs) const { + return current == rhs.current; +} + +template <typename Value, typename Key, typename GetKey, typename Hash, typename Equals, typename Allocator> +bool FlatHashTable<Value, Key, GetKey, Hash, Equals, Allocator>::const_iterator::operator!=(const_iterator const& rhs) const { + return current != rhs.current; +} + +template <typename Value, typename Key, typename GetKey, typename Hash, typename Equals, typename Allocator> +auto FlatHashTable<Value, Key, GetKey, Hash, Equals, Allocator>::const_iterator::operator++() -> const_iterator& { + current = scan(++current); + return *this; +} + +template <typename Value, typename Key, typename GetKey, typename Hash, typename Equals, typename Allocator> +auto FlatHashTable<Value, Key, GetKey, Hash, Equals, Allocator>::const_iterator::operator++(int) -> const_iterator { + const_iterator copy(*this); + operator++(); + return copy; +} + +template <typename Value, typename Key, typename GetKey, typename Hash, typename Equals, typename Allocator> +auto FlatHashTable<Value, Key, GetKey, Hash, Equals, Allocator>::const_iterator::operator*() const -> Value const& { + return *operator->(); +} + +template <typename Value, typename Key, typename GetKey, typename Hash, typename Equals, typename Allocator> +auto FlatHashTable<Value, Key, GetKey, Hash, Equals, Allocator>::const_iterator::operator->() const -> Value const* { + return current->valuePtr(); +} + +template <typename Value, typename Key, typename GetKey, typename Hash, typename Equals, typename Allocator> +bool FlatHashTable<Value, Key, GetKey, Hash, Equals, Allocator>::iterator::operator==(iterator const& rhs) const { + return current == rhs.current; +} + +template <typename Value, typename Key, typename GetKey, typename Hash, typename Equals, typename Allocator> +bool FlatHashTable<Value, Key, GetKey, Hash, Equals, Allocator>::iterator::operator!=(iterator const& rhs) const { + return current != rhs.current; +} + +template <typename Value, typename Key, typename GetKey, typename Hash, typename Equals, typename Allocator> +auto FlatHashTable<Value, Key, GetKey, Hash, Equals, Allocator>::iterator::operator++() -> iterator& { + current = scan(++current); + return *this; +} + +template <typename Value, typename Key, typename GetKey, typename Hash, typename Equals, typename Allocator> +auto FlatHashTable<Value, Key, GetKey, Hash, Equals, Allocator>::iterator::operator++(int) -> iterator { + iterator copy(*this); + operator++(); + return copy; +} + +template <typename Value, typename Key, typename GetKey, typename Hash, typename Equals, typename Allocator> +auto FlatHashTable<Value, Key, GetKey, Hash, Equals, Allocator>::iterator::operator*() const -> Value& { + return *operator->(); +} + +template <typename Value, typename Key, typename GetKey, typename Hash, typename Equals, typename Allocator> +auto FlatHashTable<Value, Key, GetKey, Hash, Equals, Allocator>::iterator::operator->() const -> Value* { + return current->valuePtr(); +} + +template <typename Value, typename Key, typename GetKey, typename Hash, typename Equals, typename Allocator> +FlatHashTable<Value, Key, GetKey, Hash, Equals, Allocator>::iterator::operator typename FlatHashTable<Value, Key, GetKey, Hash, Equals, Allocator>::const_iterator() const { + return const_iterator{current}; +} + +template <typename Value, typename Key, typename GetKey, typename Hash, typename Equals, typename Allocator> +FlatHashTable<Value, Key, GetKey, Hash, Equals, Allocator>::FlatHashTable(size_t bucketCount, + GetKey const& getKey, Hash const& hash, Equals const& equal, Allocator const& alloc) + : m_buckets(alloc), m_filledCount(0), m_getKey(getKey), + m_hash(hash), m_equals(equal) { + if (bucketCount != 0) + checkCapacity(bucketCount); +} + +template <typename Value, typename Key, typename GetKey, typename Hash, typename Equals, typename Allocator> +auto FlatHashTable<Value, Key, GetKey, Hash, Equals, Allocator>::begin() -> iterator { + if (m_buckets.empty()) + return end(); + return iterator{scan(m_buckets.data())}; +} + +template <typename Value, typename Key, typename GetKey, typename Hash, typename Equals, typename Allocator> +auto FlatHashTable<Value, Key, GetKey, Hash, Equals, Allocator>::end() -> iterator { + return iterator{m_buckets.data() + m_buckets.size() - 1}; +} + +template <typename Value, typename Key, typename GetKey, typename Hash, typename Equals, typename Allocator> +auto FlatHashTable<Value, Key, GetKey, Hash, Equals, Allocator>::begin() const -> const_iterator { + return const_cast<FlatHashTable*>(this)->begin(); +} + +template <typename Value, typename Key, typename GetKey, typename Hash, typename Equals, typename Allocator> +auto FlatHashTable<Value, Key, GetKey, Hash, Equals, Allocator>::end() const -> const_iterator { + return const_cast<FlatHashTable*>(this)->end(); +} + +template <typename Value, typename Key, typename GetKey, typename Hash, typename Equals, typename Allocator> +size_t FlatHashTable<Value, Key, GetKey, Hash, Equals, Allocator>::empty() const { + return m_filledCount == 0; +} + +template <typename Value, typename Key, typename GetKey, typename Hash, typename Equals, typename Allocator> +size_t FlatHashTable<Value, Key, GetKey, Hash, Equals, Allocator>::size() const { + return m_filledCount; +} + +template <typename Value, typename Key, typename GetKey, typename Hash, typename Equals, typename Allocator> +void FlatHashTable<Value, Key, GetKey, Hash, Equals, Allocator>::clear() { + if (m_buckets.empty()) + return; + + for (size_t i = 0; i < m_buckets.size() - 1; ++i) + m_buckets[i].setEmpty(); + m_filledCount = 0; +} + +template <typename Value, typename Key, typename GetKey, typename Hash, typename Equals, typename Allocator> +auto FlatHashTable<Value, Key, GetKey, Hash, Equals, Allocator>::insert(Value value) -> pair<iterator, bool> { + if (m_buckets.empty() || m_filledCount + 1 > (m_buckets.size() - 1) * MaxFillLevel) + checkCapacity(1); + + size_t hash = m_hash(m_getKey(value)) | FilledHashBit; + size_t targetBucket = hashBucket(hash); + size_t currentBucket = targetBucket; + size_t insertedBucket = NPos; + + while (true) { + auto& target = m_buckets[currentBucket]; + if (auto entryValue = target.valuePtr()) { + if (target.hash == hash && m_equals(m_getKey(*entryValue), m_getKey(value))) + return make_pair(iterator{m_buckets.data() + currentBucket}, false); + + size_t entryTargetBucket = hashBucket(target.hash); + size_t entryError = bucketError(currentBucket, entryTargetBucket); + size_t addError = bucketError(currentBucket, targetBucket); + if (addError > entryError) { + if (insertedBucket == NPos) + insertedBucket = currentBucket; + + swap(value, *entryValue); + swap(hash, target.hash); + targetBucket = entryTargetBucket; + } + currentBucket = hashBucket(currentBucket + 1); + + } else { + target.setFilled(hash, move(value)); + ++m_filledCount; + if (insertedBucket == NPos) + insertedBucket = currentBucket; + + return make_pair(iterator{m_buckets.data() + insertedBucket}, true); + } + } +} + +template <typename Value, typename Key, typename GetKey, typename Hash, typename Equals, typename Allocator> +auto FlatHashTable<Value, Key, GetKey, Hash, Equals, Allocator>::erase(const_iterator pos) -> iterator { + size_t bucketIndex = pos.current - m_buckets.data(); + size_t currentBucketIndex = bucketIndex; + auto currentBucket = &m_buckets[currentBucketIndex]; + + while (true) { + size_t nextBucketIndex = hashBucket(currentBucketIndex + 1); + auto nextBucket = &m_buckets[nextBucketIndex]; + if (auto nextPtr = nextBucket->valuePtr()) { + if (bucketError(nextBucketIndex, nextBucket->hash) > 0) { + currentBucket->hash = nextBucket->hash; + *currentBucket->valuePtr() = move(*nextPtr); + currentBucketIndex = nextBucketIndex; + currentBucket = nextBucket; + } else { + break; + } + } else { + break; + } + } + + m_buckets[currentBucketIndex].setEmpty(); + --m_filledCount; + + return iterator{scan(m_buckets.data() + bucketIndex)}; +} + +template <typename Value, typename Key, typename GetKey, typename Hash, typename Equals, typename Allocator> +auto FlatHashTable<Value, Key, GetKey, Hash, Equals, Allocator>::erase(const_iterator first, const_iterator last) -> iterator { + while (first != last) + first = erase(first); + return iterator{(Bucket*)first.current}; +} + +template <typename Value, typename Key, typename GetKey, typename Hash, typename Equals, typename Allocator> +auto FlatHashTable<Value, Key, GetKey, Hash, Equals, Allocator>::find(Key const& key) const -> const_iterator { + return const_cast<FlatHashTable*>(this)->find(key); +} + +template <typename Value, typename Key, typename GetKey, typename Hash, typename Equals, typename Allocator> +auto FlatHashTable<Value, Key, GetKey, Hash, Equals, Allocator>::find(Key const& key) -> iterator { + if (m_buckets.empty()) + return end(); + + size_t hash = m_hash(key) | FilledHashBit; + size_t targetBucket = hashBucket(hash); + size_t currentBucket = targetBucket; + while (true) { + auto& bucket = m_buckets[currentBucket]; + if (auto value = bucket.valuePtr()) { + if (bucket.hash == hash && m_equals(m_getKey(*value), key)) + return iterator{m_buckets.data() + currentBucket}; + + size_t entryError = bucketError(currentBucket, bucket.hash); + size_t findError = bucketError(currentBucket, targetBucket); + + if (findError > entryError) + return end(); + + currentBucket = hashBucket(currentBucket + 1); + + } else { + return end(); + } + } +} + +template <typename Value, typename Key, typename GetKey, typename Hash, typename Equals, typename Allocator> +void FlatHashTable<Value, Key, GetKey, Hash, Equals, Allocator>::reserve(size_t capacity) { + if (capacity > m_filledCount) + checkCapacity(capacity - m_filledCount); +} + +template <typename Value, typename Key, typename GetKey, typename Hash, typename Equals, typename Allocator> +Allocator FlatHashTable<Value, Key, GetKey, Hash, Equals, Allocator>::getAllocator() const { + return m_buckets.get_allocator(); +} + +template <typename Value, typename Key, typename GetKey, typename Hash, typename Equals, typename Allocator> +bool FlatHashTable<Value, Key, GetKey, Hash, Equals, Allocator>::operator==(FlatHashTable const& rhs) const { + if (size() != rhs.size()) + return false; + + auto i = begin(); + auto j = rhs.begin(); + auto e = end(); + + while (i != e) { + if (*i != *j) + return false; + ++i; + ++j; + } + + return true; +} + +template <typename Value, typename Key, typename GetKey, typename Hash, typename Equals, typename Allocator> +bool FlatHashTable<Value, Key, GetKey, Hash, Equals, Allocator>::operator!=(FlatHashTable const& rhs) const { + return !operator==(rhs); +} + +template <typename Value, typename Key, typename GetKey, typename Hash, typename Equals, typename Allocator> +constexpr size_t FlatHashTable<Value, Key, GetKey, Hash, Equals, Allocator>::MinCapacity; + +template <typename Value, typename Key, typename GetKey, typename Hash, typename Equals, typename Allocator> +constexpr double FlatHashTable<Value, Key, GetKey, Hash, Equals, Allocator>::MaxFillLevel; + +template <typename Value, typename Key, typename GetKey, typename Hash, typename Equals, typename Allocator> +auto FlatHashTable<Value, Key, GetKey, Hash, Equals, Allocator>::scan(Bucket* p) -> Bucket* { + while (p->isEmpty()) + ++p; + return p; +} + +template <typename Value, typename Key, typename GetKey, typename Hash, typename Equals, typename Allocator> +auto FlatHashTable<Value, Key, GetKey, Hash, Equals, Allocator>::scan(Bucket const* p) -> Bucket const* { + while (p->isEmpty()) + ++p; + return p; +} + +template <typename Value, typename Key, typename GetKey, typename Hash, typename Equals, typename Allocator> +size_t FlatHashTable<Value, Key, GetKey, Hash, Equals, Allocator>::hashBucket(size_t hash) const { + return hash & (m_buckets.size() - 2); +} + +template <typename Value, typename Key, typename GetKey, typename Hash, typename Equals, typename Allocator> +size_t FlatHashTable<Value, Key, GetKey, Hash, Equals, Allocator>::bucketError(size_t current, size_t target) const { + return hashBucket(current - target); +} + +template <typename Value, typename Key, typename GetKey, typename Hash, typename Equals, typename Allocator> +void FlatHashTable<Value, Key, GetKey, Hash, Equals, Allocator>::checkCapacity(size_t additionalCapacity) { + if (additionalCapacity == 0) + return; + + size_t newSize; + if (!m_buckets.empty()) + newSize = m_buckets.size() - 1; + else + newSize = MinCapacity; + + while ((double)(m_filledCount + additionalCapacity) / (double)newSize > MaxFillLevel) + newSize *= 2; + + if (newSize == m_buckets.size() - 1) + return; + + Buckets oldBuckets; + swap(m_buckets, oldBuckets); + + // Leave an extra end entry when allocating buckets, so that iterators are + // simpler and can simply iterate until they find something that is not an + // empty entry. + m_buckets.resize(newSize + 1); + while (m_buckets.capacity() > newSize * 2 + 1) { + newSize *= 2; + m_buckets.resize(newSize + 1); + } + m_buckets[newSize].setEnd(); + + m_filledCount = 0; + + for (auto& entry : oldBuckets) { + if (auto ptr = entry.valuePtr()) + insert(move(*ptr)); + } +} + +} + +#endif diff --git a/source/core/StarFont.cpp b/source/core/StarFont.cpp new file mode 100644 index 0000000..66767aa --- /dev/null +++ b/source/core/StarFont.cpp @@ -0,0 +1,123 @@ +#include "StarFont.hpp" +#include "StarFile.hpp" +#include "StarFormat.hpp" + +#include <ft2build.h> +#include FT_FREETYPE_H + +namespace Star { + +struct FTContext { + FT_Library library; + + FTContext() { + library = nullptr; + if (FT_Init_FreeType(&library)) + throw FontException("Could not initialize freetype library."); + } + + ~FTContext() { + if (library) { + FT_Done_FreeType(library); + library = nullptr; + } + } +}; + +FTContext ftContext; + +struct FontImpl { + FT_Face face; +}; + +FontPtr Font::loadTrueTypeFont(String const& fileName, unsigned pixelSize) { + return loadTrueTypeFont(make_shared<ByteArray>(File::readFile(fileName)), pixelSize); +} + +FontPtr Font::loadTrueTypeFont(ByteArrayConstPtr const& bytes, unsigned pixelSize) { + FontPtr font = make_shared<Font>(); + font->m_fontBuffer = bytes; + + shared_ptr<FontImpl> fontImpl = make_shared<FontImpl>(); + if (FT_New_Memory_Face( + ftContext.library, (FT_Byte const*)font->m_fontBuffer->ptr(), font->m_fontBuffer->size(), 0, &fontImpl->face)) + throw FontException("Could not load font from buffer"); + + font->m_fontImpl = fontImpl; + font->setPixelSize(pixelSize); + + return font; +} + +Font::Font() { + m_pixelSize = 0; +} + +FontPtr Font::clone() const { + return Font::loadTrueTypeFont(m_fontBuffer, m_pixelSize); +} + +void Font::setPixelSize(unsigned pixelSize) { + if (pixelSize == 0) { + pixelSize = 1; + } + + if (m_pixelSize == pixelSize) + return; + + if (FT_Set_Pixel_Sizes(m_fontImpl->face, pixelSize, 0)) + throw FontException(strf("Cannot set font pixel size to: %s", pixelSize)); + m_pixelSize = pixelSize; +} + +unsigned Font::height() const { + return m_pixelSize; +} + +unsigned Font::width(String::Char c) { + if (auto width = m_widthCache.maybe({c, m_pixelSize})) { + return *width; + } else { + FT_Load_Char(m_fontImpl->face, c, FT_LOAD_DEFAULT); + unsigned newWidth = (m_fontImpl->face->glyph->advance.x + 32) / 64; + m_widthCache.insert({c, m_pixelSize}, newWidth); + return newWidth; + } +} + +Image Font::render(String::Char c) { + if (!m_fontImpl) + throw FontException("Font::render called on uninitialzed font."); + + FT_UInt glyph_index = FT_Get_Char_Index(m_fontImpl->face, c); + if (FT_Load_Glyph(m_fontImpl->face, glyph_index, FT_LOAD_DEFAULT) != 0) + return {}; + + /* convert to an anti-aliased bitmap */ + if (FT_Render_Glyph(m_fontImpl->face->glyph, FT_RENDER_MODE_NORMAL) != 0) + return {}; + + FT_GlyphSlot slot = m_fontImpl->face->glyph; + + int width = (slot->advance.x + 32) / 64; + int height = m_pixelSize; + + Image image(width, height, PixelFormat::RGBA32); + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + int bx = x; + int by = y + slot->bitmap_top - m_fontImpl->face->size->metrics.ascender / 64; + if (bx >= 0 && by >= 0 && bx < (int)slot->bitmap.width && by < (int)slot->bitmap.rows) { + unsigned char* p = slot->bitmap.buffer + by * slot->bitmap.pitch; + unsigned char val = *(p + bx); + image.set(x, height - y - 1, Vec4B(255, 255, 255, val)); + } else { + image.set(x, height - y - 1, Vec4B(255, 255, 255, 0)); + } + } + } + + return image; +} + +} diff --git a/source/core/StarFont.hpp b/source/core/StarFont.hpp new file mode 100644 index 0000000..954f27a --- /dev/null +++ b/source/core/StarFont.hpp @@ -0,0 +1,48 @@ +#ifndef STAR_FONT_HPP +#define STAR_FONT_HPP + +#include "StarString.hpp" +#include "StarImage.hpp" +#include "StarByteArray.hpp" +#include "StarMap.hpp" + +namespace Star { + +STAR_EXCEPTION(FontException, StarException); + +STAR_STRUCT(FontImpl); +STAR_CLASS(Font); + +class Font { +public: + static FontPtr loadTrueTypeFont(String const& fileName, unsigned pixelSize = 12); + static FontPtr loadTrueTypeFont(ByteArrayConstPtr const& bytes, unsigned pixelSize = 12); + + Font(); + + Font(Font const&) = delete; + Font const& operator=(Font const&) = delete; + + // Create a new font from the same data + FontPtr clone() const; + + void setPixelSize(unsigned pixelSize); + + unsigned height() const; + unsigned width(String::Char c); + + // May return empty image on unrenderable character (Normally, this will + // render a box, but if there is an internal freetype error this may return + // an empty image). + Image render(String::Char c); + +private: + FontImplPtr m_fontImpl; + ByteArrayConstPtr m_fontBuffer; + unsigned m_pixelSize; + HashMap<pair<String::Char, unsigned>, unsigned> m_widthCache; +}; + +} + +#endif diff --git a/source/core/StarFormat.hpp b/source/core/StarFormat.hpp new file mode 100644 index 0000000..bb0327e --- /dev/null +++ b/source/core/StarFormat.hpp @@ -0,0 +1,100 @@ +#ifndef STAR_FORMAT_HPP +#define STAR_FORMAT_HPP + +#include "StarMemory.hpp" + +namespace Star { + +struct FormatException : public std::exception { + FormatException(std::string what) : whatmsg(move(what)) {} + + char const* what() const noexcept override { + return whatmsg.c_str(); + } + + std::string whatmsg; +}; + +} + +#define TINYFORMAT_ERROR(reason) throw Star::FormatException(reason) + +#include "tinyformat.h" + +namespace Star { + +template <typename... Args> +void format(std::ostream& out, char const* fmt, Args const&... args) { + tinyformat::format(out, fmt, args...); +} + +// Automatically flushes, use format to avoid flushing. +template <typename... Args> +void coutf(char const* fmt, Args const&... args) { + format(std::cout, fmt, args...); + std::cout.flush(); +} + +// Automatically flushes, use format to avoid flushing. +template <typename... Args> +void cerrf(char const* fmt, Args const&... args) { + format(std::cerr, fmt, args...); + std::cerr.flush(); +} + +template <typename... Args> +std::string strf(char const* fmt, Args const&... args) { + std::ostringstream os; + format(os, fmt, args...); + return os.str(); +} + +namespace OutputAnyDetail { + template<typename T, typename CharT, typename Traits> + std::basic_ostream<CharT, Traits>& output(std::basic_ostream<CharT, Traits>& os, T const& t) { + return os << "<type " << typeid(T).name() << " at address: " << &t << ">"; + } + + namespace Operator { + template<typename T, typename CharT, typename Traits> + std::basic_ostream<CharT, Traits>& operator<<(std::basic_ostream<CharT, Traits>& os, T const& t) { + return output(os, t); + } + } + + template <typename T> + struct Wrapper { + T const& wrapped; + }; + + template <typename T> + std::ostream& operator<<(std::ostream& os, Wrapper<T> const& wrapper) { + using namespace Operator; + return os << wrapper.wrapped; + } +} + +// Wraps a type so that is printable no matter what.. If no operator<< is +// defined for a type, then will print <type [typeid] at address: [address]> +template <typename T> +OutputAnyDetail::Wrapper<T> outputAny(T const& t) { + return OutputAnyDetail::Wrapper<T>{t}; +} + +struct OutputProxy { + typedef function<void(std::ostream&)> PrintFunction; + + OutputProxy(PrintFunction p) + : print(move(p)) {} + + PrintFunction print; +}; + +inline std::ostream& operator<<(std::ostream& os, OutputProxy const& p) { + p.print(os); + return os; +} + +} + +#endif diff --git a/source/core/StarFormattedJson.cpp b/source/core/StarFormattedJson.cpp new file mode 100644 index 0000000..f31f3ba --- /dev/null +++ b/source/core/StarFormattedJson.cpp @@ -0,0 +1,674 @@ +#include "StarFormattedJson.hpp" +#include "StarJsonBuilder.hpp" +#include "StarLexicalCast.hpp" + +namespace Star { + +class FormattedJsonBuilderStream : public JsonStream { +public: + virtual void beginObject(); + virtual void objectKey(String::Char const* s, size_t len); + virtual void endObject(); + + virtual void beginArray(); + virtual void endArray(); + + virtual void putString(String::Char const* s, size_t len); + virtual void putDouble(String::Char const* s, size_t len); + virtual void putInteger(String::Char const* s, size_t len); + virtual void putBoolean(bool b); + virtual void putNull(); + + virtual void putWhitespace(String::Char const* s, size_t len); + virtual void putComma(); + virtual void putColon(); + + FormattedJson takeTop(); + +private: + void push(FormattedJson const& v); + FormattedJson pop(); + FormattedJson& current(); + void putValue(Json const& value, Maybe<String> formatting = {}); + + Maybe<FormattedJson> m_root; + List<FormattedJson> m_stack; +}; + +template <> +class JsonStreamer<FormattedJson> { +public: + static void toJsonStream(FormattedJson const& val, JsonStream& stream, bool sort); +}; + +ValueElement::ValueElement(FormattedJson const& json) : value(make_shared<FormattedJson>(json)) {} + +bool ValueElement::operator==(ValueElement const& v) const { + return *value == *v.value; +} + +bool ObjectKeyElement::operator==(ObjectKeyElement const& v) const { + return key == v.key; +} + +bool WhitespaceElement::operator==(WhitespaceElement const& v) const { + return whitespace == v.whitespace; +} + +bool ColonElement::operator==(ColonElement const&) const { + return true; +} + +bool CommaElement::operator==(CommaElement const&) const { + return true; +} + +FormattedJson FormattedJson::parse(String const& string) { + return inputUtf32Json<String::const_iterator, FormattedJsonBuilderStream, FormattedJson>( + string.begin(), string.end(), true); +} + +FormattedJson FormattedJson::parseJson(String const& string) { + return inputUtf32Json<String::const_iterator, FormattedJsonBuilderStream, FormattedJson>( + string.begin(), string.end(), false); +} + +FormattedJson FormattedJson::ofType(Json::Type type) { + FormattedJson json; + json.m_jsonValue = Json::ofType(type); + return json; +} + +FormattedJson::FormattedJson() : FormattedJson(Json()) {} + +FormattedJson::FormattedJson(Json const& json) + : m_jsonValue(Json::ofType(json.type())), + m_elements(), + m_formatting(), + m_lastKey(), + m_objectEntryLocations(), + m_arrayElementLocations() { + if (json.type() == Json::Type::Object || json.type() == Json::Type::Array) { + FormattedJsonBuilderStream stream; + JsonStreamer<Json>::toJsonStream(json, stream, false); + FormattedJson parsed = stream.takeTop(); + for (JsonElement const& elem : parsed.elements()) { + appendElement(elem); + } + } + m_jsonValue = json; +} + +Json const& FormattedJson::toJson() const { + return m_jsonValue; +} + +FormattedJson FormattedJson::get(String const& key) const { + if (type() != Json::Type::Object) + throw JsonException::format("Cannot call get with key on FormattedJson type %s, must be Object type", typeName()); + + Maybe<pair<ElementLocation, ElementLocation>> entry = m_objectEntryLocations.maybe(key); + if (entry.isNothing()) + throw JsonException::format("No such key in FormattedJson::get(\"%s\")", key); + + return getFormattedJson(entry->second); +} + +FormattedJson FormattedJson::get(size_t index) const { + if (type() != Json::Type::Array) + throw JsonException::format("Cannot call get with index on FormattedJson type %s, must be Array type", typeName()); + + if (index >= m_arrayElementLocations.size()) + throw JsonException::format("FormattedJson::get(%s) out of range", index); + + ElementLocation loc = m_arrayElementLocations.at(index); + return getFormattedJson(loc); +} + +struct WhitespaceStyle { + String beforeKey; + String beforeColon; + String beforeValue; + String beforeComma; +}; + +template <class ElementType> +FormattedJson::ElementLocation indexOf(FormattedJson::ElementList const& elements, FormattedJson::ElementLocation pos) { + for (; pos < elements.size(); ++pos) { + if (elements[pos].is<ElementType>()) + return pos; + } + return NPos; +} + +template <class ElementType> +FormattedJson::ElementLocation lastIndexOf( + FormattedJson::ElementList const& elements, FormattedJson::ElementLocation pos) { + while (pos > 0) { + --pos; + if (elements[pos].is<ElementType>()) + return pos; + } + return NPos; +} + +String concatWhitespace(FormattedJson::ElementList const& elements, FormattedJson::ElementLocation from, + FormattedJson::ElementLocation to) { + String whitespace; + for (JsonElement const& elem : elements.slice(from, to)) { + if (elem.is<WhitespaceElement>()) + whitespace += elem.get<WhitespaceElement>().whitespace; + } + return whitespace; +} + +WhitespaceStyle detectWhitespace(FormattedJson::ElementList const& elements, + FormattedJson::ElementLocation insertLoc, bool array) { + WhitespaceStyle style; + + // Find a nearby value as a reference location to learn whitespace from. + FormattedJson::ElementLocation valueLoc = lastIndexOf<ValueElement>(elements, insertLoc); + if (valueLoc == NPos) + valueLoc = indexOf<ValueElement>(elements, insertLoc); + + if (valueLoc == NPos) { + // This object/array is empty. Pre-key/value whitespace will be the total of + // the whitespace already present, plus some guessed indentation if it + // contained a newline. + String beforeValue = concatWhitespace(elements, 0, elements.size()); + if (beforeValue.find('\n') != NPos) + beforeValue += " "; + if (array) + return WhitespaceStyle{"", "", beforeValue, ""}; + return WhitespaceStyle{beforeValue, "", "", ""}; + } + + FormattedJson::ElementLocation commaLoc = indexOf<CommaElement>(elements, valueLoc); + if (commaLoc != NPos) { + style.beforeComma = concatWhitespace(elements, valueLoc + 1, commaLoc); + } + + FormattedJson::ElementLocation colonLoc = lastIndexOf<ColonElement>(elements, valueLoc); + starAssert((colonLoc == NPos) == array); + if (colonLoc != NPos) { + style.beforeValue = concatWhitespace(elements, colonLoc + 1, valueLoc); + + FormattedJson::ElementLocation keyLoc = lastIndexOf<ObjectKeyElement>(elements, colonLoc); + starAssert(keyLoc != NPos); + style.beforeColon = concatWhitespace(elements, keyLoc + 1, colonLoc); + + FormattedJson::ElementLocation prevValueLoc = lastIndexOf<ValueElement>(elements, keyLoc); + if (prevValueLoc == NPos) + prevValueLoc = 0; + style.beforeKey = concatWhitespace(elements, prevValueLoc, keyLoc); + + } else { + FormattedJson::ElementLocation prevValueLoc = lastIndexOf<ValueElement>(elements, valueLoc); + if (prevValueLoc == NPos) + prevValueLoc = 0; + style.beforeValue = concatWhitespace(elements, prevValueLoc, valueLoc); + } + + return style; +} + +void insertWhitespace(FormattedJson::ElementList& destination, FormattedJson::ElementLocation& at, String const& whitespace) { + if (whitespace == "") + return; + destination.insertAt(at++, WhitespaceElement{whitespace}); +} + +void insertWithWhitespace(FormattedJson::ElementList& destination, WhitespaceStyle const& style, + FormattedJson::ElementLocation& at, JsonElement const& element) { + if (element.is<ValueElement>()) + insertWhitespace(destination, at, style.beforeValue); + if (element.is<ObjectKeyElement>()) + insertWhitespace(destination, at, style.beforeKey); + if (element.is<ColonElement>()) + insertWhitespace(destination, at, style.beforeColon); + if (element.is<CommaElement>()) + insertWhitespace(destination, at, style.beforeComma); + destination.insertAt(at++, element); +} + +void insertWithCommaAndFormatting(FormattedJson::ElementList& destination, FormattedJson::ElementLocation at, + bool array, FormattedJson::ElementList const& elements) { + // Find the previous value we're inserting after, if any. + at = lastIndexOf<ValueElement>(destination, at); + if (at == NPos) + at = 0; + else + at += 1; + bool empty = lastIndexOf<ValueElement>(destination, destination.size()) == NPos; + bool appendComma = at == 0 && !empty; + bool prependComma = !appendComma && !empty; + + WhitespaceStyle style = detectWhitespace(destination, at, array); + + if (prependComma) { + // Inserting into a non-empty object/array. Prepend a comma + insertWithWhitespace(destination, style, at, CommaElement{}); + } + for (JsonElement const& elem : elements) { + insertWithWhitespace(destination, style, at, elem); + } + if (appendComma) { + insertWithWhitespace(destination, style, at, CommaElement{}); + } +} + +FormattedJson FormattedJson::prepend(String const& key, FormattedJson const& value) const { + return objectInsert(key, value, 0); +} + +FormattedJson FormattedJson::insertBefore(String const& key, FormattedJson const& value, String const& beforeKey) const { + if (!m_objectEntryLocations.contains(beforeKey)) + throw JsonException::format("Cannot insert before key \"%s\", which does not exist", beforeKey); + ElementLocation loc = m_objectEntryLocations.get(beforeKey).first; + return objectInsert(key, value, loc); +} + +FormattedJson FormattedJson::insertAfter(String const& key, FormattedJson const& value, String const& afterKey) const { + if (!m_objectEntryLocations.contains(afterKey)) + throw JsonException::format("Cannot insert after key \"%s\", which does not exist", afterKey); + ElementLocation loc = m_objectEntryLocations.get(afterKey).second; + return objectInsert(key, value, loc + 1); +} + +FormattedJson FormattedJson::append(String const& key, FormattedJson const& value) const { + return objectInsert(key, value, m_elements.size()); +} + +FormattedJson FormattedJson::set(String const& key, FormattedJson const& value) const { + return objectInsert(key, value, m_elements.size()); +} + +void removeValueFromArray(List<JsonElement>& elements, size_t loc) { + // Remove the value itself, the comma following and the whitespace up to the + // next value. + // If it's the last value, it removes the value, and the preceding whitespace + // and comma. + size_t commaLoc = elements.indexOf(CommaElement{}, loc); + if (commaLoc != NPos) { + elements.eraseAt(loc, commaLoc + 1); + while (loc < elements.size() && elements.at(loc).is<WhitespaceElement>()) + elements.eraseAt(loc); + } else { + commaLoc = elements.lastIndexOf(CommaElement{}, loc); + if (commaLoc == NPos) + commaLoc = 0; + elements.eraseAt(commaLoc, loc + 1); + } +} + +FormattedJson FormattedJson::eraseKey(String const& key) const { + if (type() != Json::Type::Object) + throw JsonException::format("Cannot call erase with key on FormattedJson type %s, must be Object type", typeName()); + + Maybe<pair<ElementLocation, ElementLocation>> maybeEntry = m_objectEntryLocations.maybe(key); + if (maybeEntry.isNothing()) + return *this; + + ElementLocation loc = maybeEntry->first; + ElementList elements = m_elements; + elements.eraseAt(loc, maybeEntry->second); // Remove key, colon and whitespace up to the value + removeValueFromArray(elements, loc); + return object(elements); +} + +FormattedJson FormattedJson::insert(size_t index, FormattedJson const& value) const { + if (type() != Json::Type::Array) + throw JsonException::format( + "Cannot call insert with index on FormattedJson type %s, must be Array type", typeName()); + + if (index > m_arrayElementLocations.size()) + throw JsonException::format("FormattedJson::insert(%s) out of range", index); + + ElementList elements = m_elements; + ElementLocation insertPosition = elements.size(); + if (index < m_arrayElementLocations.size()) + insertPosition = m_arrayElementLocations.at(index); + + insertWithCommaAndFormatting(elements, insertPosition, true, {ValueElement{value}}); + return array(elements); +} + +FormattedJson FormattedJson::append(FormattedJson const& value) const { + if (type() != Json::Type::Array) + throw JsonException::format("Cannot call append on FormattedJson type %s, must be Array type", typeName()); + + ElementList elements = m_elements; + insertWithCommaAndFormatting(elements, elements.size(), true, {ValueElement{value}}); + return array(elements); +} + +FormattedJson FormattedJson::set(size_t index, FormattedJson const& value) const { + if (type() != Json::Type::Array) + throw JsonException::format("Cannot call set with index on FormattedJson type %s, must be Array type", typeName()); + + if (index >= m_arrayElementLocations.size()) + throw JsonException::format("FormattedJson::set(%s) out of range", index); + + ElementLocation loc = m_arrayElementLocations.at(index); + ElementList elements = m_elements; + elements.at(loc) = ValueElement{value}; + return array(elements); +} + +FormattedJson FormattedJson::eraseIndex(size_t index) const { + if (type() != Json::Type::Array) + throw JsonException::format("Cannot call set with index on FormattedJson type %s, must be Array type", typeName()); + + if (index >= m_arrayElementLocations.size()) + throw JsonException::format("FormattedJson::eraseIndex(%s) out of range", index); + + ElementLocation loc = m_arrayElementLocations.at(index); + ElementList elements = m_elements; + removeValueFromArray(elements, loc); + return array(elements); +} + +size_t FormattedJson::size() const { + return m_jsonValue.size(); +} + +bool FormattedJson::contains(String const& key) const { + return m_jsonValue.contains(key); +} + +Json::Type FormattedJson::type() const { + return m_jsonValue.type(); +} + +bool FormattedJson::isType(Json::Type type) const { + return m_jsonValue.isType(type); +} + +String FormattedJson::typeName() const { + return m_jsonValue.typeName(); +} + +String FormattedJson::toFormattedDouble() const { + if (!isType(Json::Type::Float)) + throw JsonException::format("Cannot call toFormattedDouble on Json type %s, must be Float", typeName()); + if (m_formatting.isValid()) + return *m_formatting; + return toJson().repr(); +} + +String FormattedJson::toFormattedInt() const { + if (!isType(Json::Type::Int)) + throw JsonException::format("Cannot call toFormattedInt on Json type %s, must be Int", typeName()); + if (m_formatting.isValid()) + return *m_formatting; + return toJson().repr(); +} + +String FormattedJson::repr() const { + if (m_formatting.isValid()) + return *m_formatting; + String result; + outputUtf32Json<std::back_insert_iterator<String>, FormattedJson>(*this, std::back_inserter(result), 0, false); + return result; +} + +String FormattedJson::printJson() const { + if (type() != Json::Type::Object && type() != Json::Type::Array) + throw JsonException("printJson called on non-top-level JSON type"); + return repr(); +} + +Json elemToJson(JsonElement const& elem) { + return elem.get<ValueElement>().value->toJson(); +} + +FormattedJson::ElementList const& FormattedJson::elements() const { + return m_elements; +} + +bool FormattedJson::operator==(FormattedJson const& v) const { + return m_jsonValue == v.m_jsonValue; +} + +bool FormattedJson::operator!=(FormattedJson const& v) const { + return !(*this == v); +} + +FormattedJson FormattedJson::object(ElementList const& elements) { + FormattedJson json = ofType(Json::Type::Object); + for (JsonElement const& elem : elements) { + json.appendElement(elem); + } + return json; +} + +FormattedJson FormattedJson::array(ElementList const& elements) { + FormattedJson json = ofType(Json::Type::Array); + for (JsonElement const& elem : elements) { + if (elem.is<ColonElement>() || elem.is<ObjectKeyElement>()) + throw JsonException("Invalid FormattedJson element in Json array"); + json.appendElement(elem); + } + return json; +} + +FormattedJson FormattedJson::objectInsert(String const& key, FormattedJson const& value, ElementLocation loc) const { + if (type() != Json::Type::Object) + throw JsonException::format("Cannot call set with key on FormattedJson type %s, must be Object type", typeName()); + + Maybe<pair<ElementLocation, ElementLocation>> maybeEntry = m_objectEntryLocations.maybe(key); + if (maybeEntry.isValid()) { + ElementList elements = m_elements; + elements.at(maybeEntry->second) = ValueElement{value}; + return object(elements); + } + + ElementList elements = m_elements; + insertWithCommaAndFormatting(elements, loc, false, {ObjectKeyElement{key}, ColonElement{}, ValueElement{value}}); + return object(elements); +} + +void FormattedJson::appendElement(JsonElement const& elem) { + ElementLocation loc = m_elements.size(); + m_elements.append(elem); + + if (elem.is<ObjectKeyElement>()) { + starAssert(isType(Json::Type::Object)); + m_lastKey = loc; + + } else if (elem.is<ValueElement>()) { + m_lastValue = loc; + + if (m_lastKey.isValid()) { + starAssert(isType(Json::Type::Object)); + String key = m_elements[*m_lastKey].get<ObjectKeyElement>().key; + + m_objectEntryLocations[key] = make_pair(*m_lastKey, loc); + m_jsonValue = m_jsonValue.set(key, elemToJson(elem)); + + m_lastKey = {}; + } else { + starAssert(isType(Json::Type::Array)); + m_arrayElementLocations.append(loc); + + m_jsonValue = m_jsonValue.append(elemToJson(elem)); + } + } +} + +FormattedJson const& FormattedJson::getFormattedJson(ElementLocation loc) const { + return *m_elements[loc].get<ValueElement>().value; +} + +FormattedJson FormattedJson::formattedAs(String const& formatting) const { + starAssert(Json::parse(formatting) == toJson()); + FormattedJson json = *this; + json.m_formatting = formatting; + return json; +} + +void FormattedJsonBuilderStream::beginObject() { + FormattedJson value = FormattedJson::ofType(Json::Type::Object); + push(value); +} + +void FormattedJsonBuilderStream::objectKey(String::Char const* s, size_t len) { + current().appendElement(ObjectKeyElement{String(s, len)}); +} + +void FormattedJsonBuilderStream::endObject() { + FormattedJson value = pop(); + if (m_stack.size() > 0) + current().appendElement(ValueElement{value}); + else + m_root = value; +} + +void FormattedJsonBuilderStream::beginArray() { + FormattedJson value = FormattedJson::ofType(Json::Type::Array); + push(value); +} + +void FormattedJsonBuilderStream::endArray() { + FormattedJson value = pop(); + if (m_stack.size() > 0) + current().appendElement(ValueElement{value}); + else + m_root = value; +} + +void FormattedJsonBuilderStream::putString(String::Char const* s, size_t len) { + putValue(String(s, len)); +} + +void FormattedJsonBuilderStream::putDouble(String::Char const* s, size_t len) { + String formatted(s, len); + double d = lexicalCast<double>(formatted); + putValue(d, formatted); +} + +void FormattedJsonBuilderStream::putInteger(String::Char const* s, size_t len) { + String formatted(s, len); + long long d = lexicalCast<long long>(formatted); + putValue(d, formatted); +} + +void FormattedJsonBuilderStream::putBoolean(bool b) { + putValue(b); +} + +void FormattedJsonBuilderStream::putNull() { + putValue(Json::ofType(Json::Type::Null)); +} + +void FormattedJsonBuilderStream::putWhitespace(String::Char const* s, size_t len) { + if (m_stack.size() > 0) + current().appendElement(WhitespaceElement{String(s, len)}); +} + +void FormattedJsonBuilderStream::putColon() { + current().appendElement(ColonElement{}); +} + +void FormattedJsonBuilderStream::putComma() { + current().appendElement(CommaElement{}); +} + +FormattedJson FormattedJsonBuilderStream::takeTop() { + return m_root.take(); +} + +void FormattedJsonBuilderStream::push(FormattedJson const& v) { + m_stack.push_back(v); +} + +FormattedJson FormattedJsonBuilderStream::pop() { + FormattedJson result = m_stack.back(); + m_stack.pop_back(); + return result; +} + +FormattedJson& FormattedJsonBuilderStream::current() { + return m_stack.back(); +} + +void FormattedJsonBuilderStream::putValue(Json const& value, Maybe<String> formatting) { + FormattedJson formattedValue = value; + if (formatting.isValid()) + formattedValue = formattedValue.formattedAs(*formatting); + + if (m_stack.size() > 0) + current().appendElement(ValueElement{formattedValue}); + else { + m_root = formattedValue; + } +} + +void JsonStreamer<FormattedJson>::toJsonStream(FormattedJson const& val, JsonStream& stream, bool sort) { + if (val.isType(Json::Type::Object)) + stream.beginObject(); + + else if (val.isType(Json::Type::Array)) + stream.beginArray(); + + else if (val.isType(Json::Type::Float)) { + // Float and Int are to be formatted the same way they were parsed to + // preserve, e.g. negative zeroes and trailing 0 digits on decimals. + auto ws = val.toFormattedDouble().wideString(); + stream.putDouble(ws.c_str(), ws.length()); + return; + + } else if (val.isType(Json::Type::Int)) { + auto ws = val.toFormattedInt().wideString(); + stream.putInteger(ws.c_str(), ws.length()); + return; + + } else { + // If val is not an object, array or number, it has no formatting and no + // elements. Stream the wrapped Json value the usual way. + JsonStreamer<Json>::toJsonStream(val.toJson(), stream, sort); + return; + } + + for (JsonElement elem : val.elements()) { + if (elem.is<ObjectKeyElement>()) { + String::WideString key = elem.get<ObjectKeyElement>().key.wideString(); + stream.objectKey(key.c_str(), key.length()); + } else if (elem.is<WhitespaceElement>()) { + String::WideString white = elem.get<WhitespaceElement>().whitespace.wideString(); + stream.putWhitespace(white.c_str(), white.length()); + } else if (elem.is<ColonElement>()) { + stream.putColon(); + } else if (elem.is<CommaElement>()) { + stream.putComma(); + } else { + toJsonStream(*elem.get<ValueElement>().value, stream, sort); + } + } + + if (val.isType(Json::Type::Object)) + stream.endObject(); + if (val.isType(Json::Type::Array)) + stream.endArray(); +} + +std::ostream& operator<<(std::ostream& os, JsonElement const& elem) { + if (elem.is<ValueElement>()) + return os << "ValueElement{" << elem.get<ValueElement>().value << "}"; + if (elem.is<ObjectKeyElement>()) + return os << "ObjectKeyElement{" << elem.get<ObjectKeyElement>().key << "}"; + if (elem.is<WhitespaceElement>()) + return os << "WhitespaceElement{" << elem.get<WhitespaceElement>().whitespace << "}"; + if (elem.is<ColonElement>()) + return os << "ColonElement{}"; + if (elem.is<CommaElement>()) + return os << "CommaElement{}"; + starAssert(false); + return os; +} + +std::ostream& operator<<(std::ostream& os, FormattedJson const& json) { + return os << json.repr(); +} + +} diff --git a/source/core/StarFormattedJson.hpp b/source/core/StarFormattedJson.hpp new file mode 100644 index 0000000..555d436 --- /dev/null +++ b/source/core/StarFormattedJson.hpp @@ -0,0 +1,133 @@ +#ifndef STAR_JSON_COMMENTS_HPP +#define STAR_JSON_COMMENTS_HPP + +#include <list> + +#include "StarJson.hpp" + +namespace Star { + +STAR_CLASS(FormattedJson); + +struct ObjectElement; +struct ObjectKeyElement; +struct ValueElement; +struct WhitespaceElement; +struct ColonElement; +struct CommaElement; + +typedef Variant<ValueElement, ObjectKeyElement, WhitespaceElement, ColonElement, CommaElement> JsonElement; + +struct ValueElement { + ValueElement(FormattedJson const& json); + + FormattedJsonPtr value; + + bool operator==(ValueElement const& v) const; +}; + +struct ObjectKeyElement { + String key; + + bool operator==(ObjectKeyElement const& v) const; +}; + +struct WhitespaceElement { + String whitespace; + + bool operator==(WhitespaceElement const& v) const; +}; + +struct ColonElement { + bool operator==(ColonElement const&) const; +}; + +struct CommaElement { + bool operator==(CommaElement const&) const; +}; + +std::ostream& operator<<(std::ostream& os, JsonElement const& elem); + +// Class representing formatted JSON data. Preserves whitespace and comments. +class FormattedJson { +public: + typedef List<JsonElement> ElementList; + typedef size_t ElementLocation; + + static FormattedJson parse(String const& string); + static FormattedJson parseJson(String const& string); + + static FormattedJson ofType(Json::Type type); + + FormattedJson(); + FormattedJson(Json const&); + + Json const& toJson() const; + + FormattedJson get(String const& key) const; + FormattedJson get(size_t index) const; + + // Returns a new FormattedJson with the given values added or erased. + // Prepend, insert and append update the value in-place if the key already + // exists. + FormattedJson prepend(String const& key, FormattedJson const& value) const; + FormattedJson insertBefore(String const& key, FormattedJson const& value, String const& beforeKey) const; + FormattedJson insertAfter(String const& key, FormattedJson const& value, String const& afterKey) const; + FormattedJson append(String const& key, FormattedJson const& value) const; + FormattedJson set(String const& key, FormattedJson const& value) const; + FormattedJson eraseKey(String const& key) const; + + FormattedJson insert(size_t index, FormattedJson const& value) const; + FormattedJson append(FormattedJson const& value) const; + FormattedJson set(size_t index, FormattedJson const& value) const; + FormattedJson eraseIndex(size_t index) const; + + // Returns the number of elements in a Json array, or entries in an object. + size_t size() const; + + bool contains(String const& key) const; + + Json::Type type() const; + bool isType(Json::Type type) const; + String typeName() const; + + String toFormattedDouble() const; + String toFormattedInt() const; + + String repr() const; + String printJson() const; + + ElementList const& elements() const; + + // Equality ignores whitespace and formatting. It just compares the Json + // values. + bool operator==(FormattedJson const& v) const; + bool operator!=(FormattedJson const& v) const; + +private: + friend class FormattedJsonBuilderStream; + + static FormattedJson object(ElementList const& elements); + static FormattedJson array(ElementList const& elements); + + FormattedJson objectInsert(String const& key, FormattedJson const& value, ElementLocation loc) const; + void appendElement(JsonElement const& elem); + + FormattedJson const& getFormattedJson(ElementLocation loc) const; + FormattedJson formattedAs(String const& formatting) const; + + Json m_jsonValue; + ElementList m_elements; + // Used to preserve the formatting of numbers, i.e. -0 vs 0, 1.0 vs 1: + Maybe<String> m_formatting; + + Maybe<ElementLocation> m_lastKey, m_lastValue; + Map<String, pair<ElementLocation, ElementLocation>> m_objectEntryLocations; + List<ElementLocation> m_arrayElementLocations; +}; + +std::ostream& operator<<(std::ostream& os, FormattedJson const& json); + +} + +#endif diff --git a/source/core/StarHash.hpp b/source/core/StarHash.hpp new file mode 100644 index 0000000..0e3f67c --- /dev/null +++ b/source/core/StarHash.hpp @@ -0,0 +1,100 @@ +#ifndef STAR_HASH_HPP +#define STAR_HASH_HPP + +#include "StarBytes.hpp" + +namespace Star { + +// To avoid having to specialize std::hash in the std namespace, which is +// slightly annoying, Star type wrappers use Star::hash, which just defaults to +// std::hash. Star::hash also enables template specialization with a dummy +// Enable parameter. +template <typename T, typename Enable = void> +struct hash : public std::hash<T> {}; + +inline void hashCombine(size_t& hash, size_t comb) { + hash ^= comb * 2654435761 + 0x9e3779b9 + (hash << 6) + (hash >> 2); +} + +// Paul Larson hashing algorithm, very very *cheap* hashing function. +class PLHasher { +public: + PLHasher(size_t initial = 0) + : m_hash(initial) {} + + template <typename T> + void put(T b) { + m_hash = m_hash * 101 + (size_t)b; + } + + size_t hash() const { + return m_hash; + } + +private: + size_t m_hash; +}; + +template <typename first_t, typename second_t> +class hash<std::pair<first_t, second_t>> { +private: + Star::hash<first_t> firstHasher; + Star::hash<second_t> secondHasher; + +public: + size_t operator()(std::pair<first_t, second_t> const& a) const { + size_t hashval = firstHasher(a.first); + hashCombine(hashval, secondHasher(a.second)); + return hashval; + } +}; + +template <typename... TTypes> +class hash<std::tuple<TTypes...>> { +private: + typedef std::tuple<TTypes...> Tuple; + + template <size_t N> + size_t operator()(Tuple const&) const { + return 0; + } + + template <size_t N, typename THead, typename... TTail> + size_t operator()(Tuple const& value) const { + size_t hash = Star::hash<THead>()(std::get<N - sizeof...(TTail) - 1>(value)); + hashCombine(hash, operator()<N, TTail...>(value)); + return hash; + } + +public: + size_t operator()(Tuple const& value) const { + return operator()<sizeof...(TTypes), TTypes...>(value); + } +}; + +template <typename EnumType> +class hash<EnumType, typename std::enable_if<std::is_enum<EnumType>::value>::type> { +private: + typedef typename std::underlying_type<EnumType>::type UnderlyingType; + +public: + size_t operator()(EnumType e) const { + return std::hash<UnderlyingType>()((UnderlyingType)e); + } +}; + +template <typename T> +size_t hashOf(T const& t) { + return Star::hash<T>()(t); +} + +template <typename T1, typename T2, typename... TL> +size_t hashOf(T1 const& t1, T2 const& t2, TL const&... rest) { + size_t hash = hashOf(t1); + hashCombine(hash, hashOf(t2, rest...)); + return hash; +}; + +} + +#endif diff --git a/source/core/StarHostAddress.cpp b/source/core/StarHostAddress.cpp new file mode 100644 index 0000000..f65f179 --- /dev/null +++ b/source/core/StarHostAddress.cpp @@ -0,0 +1,281 @@ +#include "StarHostAddress.hpp" +#include "StarLexicalCast.hpp" +#include "StarNetImpl.hpp" + +namespace Star { + +HostAddress HostAddress::localhost(NetworkMode mode) { + if (mode == NetworkMode::IPv4) { + uint8_t addr[4] = {127, 0, 0, 1}; + return HostAddress(mode, addr); + } else if (mode == NetworkMode::IPv6) { + uint8_t addr[16] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}; + return HostAddress(mode, addr); + } + + return HostAddress(); +} + +Either<String, HostAddress> HostAddress::lookup(String const& address) { + try { + HostAddress ha; + ha.set(address); + return makeRight(move(ha)); + } catch (NetworkException const& e) { + return makeLeft(String(e.what())); + } +} + +HostAddress::HostAddress(NetworkMode mode, uint8_t* address) { + set(mode, address); +} + +HostAddress::HostAddress(String const& address) { + auto a = lookup(address); + if (a.isLeft()) + throw NetworkException(a.left().takeUtf8()); + else + *this = move(a.right()); +} + +NetworkMode HostAddress::mode() const { + return m_mode; +} + +uint8_t const* HostAddress::bytes() const { + return m_address; +} + +uint8_t HostAddress::octet(size_t i) const { + return m_address[i]; +} + +bool HostAddress::isLocalHost() const { + if (m_mode == NetworkMode::IPv4) { + return (m_address[0] == 127 && m_address[1] == 0 && m_address[2] == 0 && m_address[3] == 1); + + } else { + for (size_t i = 0; i < 15; ++i) { + if (m_address[i] != 0) + return false; + } + + return m_address[15] == 1; + } +} + +bool HostAddress::isZero() const { + if (mode() == NetworkMode::IPv4) + return m_address[0] == 0 && m_address[1] == 0 && m_address[2] == 0 && m_address[3] == 0; + + if (mode() == NetworkMode::IPv6) { + for (size_t i = 0; i < 16; i++) { + if (m_address[i] != 0) + return false; + } + return true; + } + + return false; +} + +size_t HostAddress::size() const { + switch (m_mode) { + case NetworkMode::IPv4: + return 4; + case NetworkMode::IPv6: + return 16; + default: + return 0; + } +} + +bool HostAddress::operator==(HostAddress const& a) const { + if (m_mode != a.m_mode) + return false; + + size_t len = a.size(); + for (size_t i = 0; i < len; i++) { + if (m_address[i] != a.m_address[i]) + return false; + } + + return true; +} + +void HostAddress::set(String const& address) { + if (address.empty()) + return; + + if (address.compare("*") == 0 || address.compare("0.0.0.0") == 0) { + uint8_t inaddr_any[4]; + memset(inaddr_any, 0, sizeof(inaddr_any)); + set(NetworkMode::IPv4, inaddr_any); + } else if (address.compare("::") == 0) { + // NOTE: This will likely bind to both IPv6 and IPv4, but it does depending + // on the OS settings + uint8_t inaddr_any[16]; + memset(inaddr_any, 0, sizeof(inaddr_any)); + set(NetworkMode::IPv6, inaddr_any); + } else { + struct addrinfo* result = NULL; + struct addrinfo* ptr = NULL; + struct addrinfo hints; + + memset(&hints, 0, sizeof(hints)); + hints.ai_family = AF_UNSPEC; + // Eliminate duplicates being returned one for each socket type. + // As we're not using the return socket type or protocol this doesn't effect + // us. + hints.ai_socktype = SOCK_STREAM; + hints.ai_protocol = IPPROTO_TCP; + // Request only usable addresses e.g. IPv6 only if IPv6 is available + hints.ai_flags = AI_ADDRCONFIG; + + if (::getaddrinfo(address.utf8Ptr(), NULL, &hints, &result) != 0) + throw NetworkException(strf("Failed to determine address for '%s' (%s)", address, netErrorString())); + + for (ptr = result; ptr != NULL; ptr = ptr->ai_next) { + NetworkMode mode; + switch (ptr->ai_family) { + case AF_INET: + mode = NetworkMode::IPv4; + break; + case AF_INET6: + mode = NetworkMode::IPv6; + break; + default: + continue; + } + if (mode == NetworkMode::IPv4) { + struct sockaddr_in* info = (struct sockaddr_in*)ptr->ai_addr; + set(mode, (uint8_t*)(&info->sin_addr)); + } else { + struct sockaddr_in6* info = (struct sockaddr_in6*)ptr->ai_addr; + set(mode, (uint8_t*)(&info->sin6_addr)); + } + break; + } + freeaddrinfo(result); + } +} + +void HostAddress::set(NetworkMode mode, uint8_t const* addr) { + m_mode = mode; + if (addr) + memcpy(m_address, addr, size()); + else + memset(m_address, 0, 16); +} + +std::ostream& operator<<(std::ostream& os, HostAddress const& address) { + switch (address.mode()) { + case NetworkMode::IPv4: + format(os, "%d.%d.%d.%d", address.octet(0), address.octet(1), address.octet(2), address.octet(3)); + break; + + case NetworkMode::IPv6: + format(os, + "%02x%02x:%02x%02x:%02x%02x:%02x%02x:%02x%02x:%02x%02x:%02x%02x:%02x%02x", + address.octet(0), + address.octet(1), + address.octet(2), + address.octet(3), + address.octet(4), + address.octet(5), + address.octet(6), + address.octet(7), + address.octet(8), + address.octet(9), + address.octet(10), + address.octet(11), + address.octet(12), + address.octet(13), + address.octet(14), + address.octet(15)); + break; + + default: + throw NetworkException(strf("Unknown address mode (%d)", (int)address.mode())); + } + return os; +} + +size_t hash<HostAddress>::operator()(HostAddress const& address) const { + PLHasher hash; + for (size_t i = 0; i < address.size(); ++i) + hash.put(address.octet(i)); + return hash.hash(); +} + +Either<String, HostAddressWithPort> HostAddressWithPort::lookup(String const& address, uint16_t port) { + auto hostAddress = HostAddress::lookup(address); + if (hostAddress.isLeft()) + return makeLeft(move(hostAddress.left())); + else + return makeRight(HostAddressWithPort(move(hostAddress.right()), port)); +} + +Either<String, HostAddressWithPort> HostAddressWithPort::lookupWithPort(String const& address) { + String host = address; + String port = host.rextract(":"); + if (host.beginsWith("[") && host.endsWith("]")) + host = host.substr(1, host.size() - 2); + + auto portNum = maybeLexicalCast<uint16_t>(port); + if (!portNum) + return makeLeft(strf("Could not parse port portion of HostAddressWithPort '%s'", port)); + + auto hostAddress = HostAddress::lookup(host); + if (hostAddress.isLeft()) + return makeLeft(move(hostAddress.left())); + + return makeRight(HostAddressWithPort(move(hostAddress.right()), *portNum)); +} + +HostAddressWithPort::HostAddressWithPort() : m_port(0) {} + +HostAddressWithPort::HostAddressWithPort(HostAddress const& address, uint16_t port) + : m_address(address), m_port(port) {} + +HostAddressWithPort::HostAddressWithPort(NetworkMode mode, uint8_t* address, uint16_t port) { + m_address = HostAddress(mode, address); + m_port = port; +} + +HostAddressWithPort::HostAddressWithPort(String const& address, uint16_t port) { + auto a = lookup(address, port); + if (a.isLeft()) + throw NetworkException(a.left().takeUtf8()); + *this = move(a.right()); +} + +HostAddressWithPort::HostAddressWithPort(String const& address) { + auto a = lookupWithPort(address); + if (a.isLeft()) + throw NetworkException(a.left().takeUtf8()); + *this = move(a.right()); +} + +HostAddress HostAddressWithPort::address() const { + return m_address; +} + +uint16_t HostAddressWithPort::port() const { + return m_port; +} + +bool HostAddressWithPort::operator==(HostAddressWithPort const& rhs) const { + return tie(m_address, m_port) == tie(rhs.m_address, rhs.m_port); +} + +std::ostream& operator<<(std::ostream& os, HostAddressWithPort const& addressWithPort) { + os << addressWithPort.address() << ":" << addressWithPort.port(); + return os; +} + +size_t hash<HostAddressWithPort>::operator()(HostAddressWithPort const& addressWithPort) const { + return hashOf(addressWithPort.address(), addressWithPort.port()); +} + +} diff --git a/source/core/StarHostAddress.hpp b/source/core/StarHostAddress.hpp new file mode 100644 index 0000000..2f5a128 --- /dev/null +++ b/source/core/StarHostAddress.hpp @@ -0,0 +1,89 @@ +#ifndef STAR_HOST_ADDRESS_HPP +#define STAR_HOST_ADDRESS_HPP + +#include "StarString.hpp" +#include "StarEither.hpp" + +namespace Star { + +STAR_EXCEPTION(NetworkException, IOException); + +STAR_CLASS(HostAddress); + +enum class NetworkMode { + IPv4, + IPv6 +}; + +class HostAddress { +public: + static HostAddress localhost(NetworkMode mode = NetworkMode::IPv4); + + // Returns either error or valid HostAddress + static Either<String, HostAddress> lookup(String const& address); + + // If address is nullptr, constructs the zero address. + HostAddress(NetworkMode mode = NetworkMode::IPv4, uint8_t* address = nullptr); + // Throws if address is not valid + explicit HostAddress(String const& address); + + NetworkMode mode() const; + uint8_t const* bytes() const; + uint8_t octet(size_t i) const; + size_t size() const; + + bool isLocalHost() const; + bool isZero() const; + + bool operator==(HostAddress const& a) const; + +private: + void set(String const& address); + void set(NetworkMode mode, uint8_t const* addr); + + NetworkMode m_mode; + uint8_t m_address[16]; +}; + +std::ostream& operator<<(std::ostream& os, HostAddress const& address); + +template <> +struct hash<HostAddress> { + size_t operator()(HostAddress const& address) const; +}; + +class HostAddressWithPort { +public: + // Returns either error or valid HostAddressWithPort + static Either<String, HostAddressWithPort> lookup(String const& address, uint16_t port); + // Format may have [] brackets around address or not, to distinguish address + // portion from port portion. + static Either<String, HostAddressWithPort> lookupWithPort(String const& address); + + HostAddressWithPort(); + HostAddressWithPort(HostAddress const& address, uint16_t port); + HostAddressWithPort(NetworkMode mode, uint8_t* address, uint16_t port); + // Throws if address or port is not valid + HostAddressWithPort(String const& address, uint16_t port); + explicit HostAddressWithPort(String const& address); + + HostAddress address() const; + uint16_t port() const; + + bool operator==(HostAddressWithPort const& a) const; + +private: + HostAddress m_address; + uint16_t m_port; +}; + +std::ostream& operator<<(std::ostream& os, HostAddressWithPort const& address); + +template <> +struct hash<HostAddressWithPort> { + size_t operator()(HostAddressWithPort const& address) const; +}; + +} + +#endif diff --git a/source/core/StarIODevice.cpp b/source/core/StarIODevice.cpp new file mode 100644 index 0000000..d87d33f --- /dev/null +++ b/source/core/StarIODevice.cpp @@ -0,0 +1,153 @@ +#include "StarIODevice.hpp" +#include "StarMathCommon.hpp" +#include "StarFormat.hpp" + +namespace Star { + +IODevice::IODevice(IOMode mode) { + m_mode = mode; +} + +IODevice::~IODevice() { + close(); +} + +void IODevice::resize(StreamOffset) { + throw IOException("resize not supported"); +} + +void IODevice::readFull(char* data, size_t len) { + size_t r = read(data, len); + if (r < len) { + if (atEnd()) + throw EofException("Failed to read full buffer in readFull, eof reached."); + else + throw IOException("Failed to read full buffer in readFull"); + } + data += r; + len -= r; +} + +void IODevice::writeFull(char const* data, size_t len) { + size_t r = write(data, len); + if (r < len) { + if (atEnd()) + throw EofException("Failed to write full buffer in writeFull, eof reached."); + else + throw IOException("Failed to write full buffer in writeFull"); + } + data += r; + len -= r; +} + +void IODevice::open(IOMode mode) { + if (mode != m_mode) + throw IOException::format("Cannot reopen device '%s", deviceName()); +} + +void IODevice::close() { + m_mode = IOMode::Closed; +} + +void IODevice::sync() {} + +String IODevice::deviceName() const { + return strf("IODevice <%s>", this); +} + +bool IODevice::atEnd() { + return pos() >= size(); +} + +StreamOffset IODevice::size() { + try { + StreamOffset storedPos = pos(); + seek(0, IOSeek::End); + StreamOffset size = pos(); + seek(storedPos); + return size; + } catch (IOException const& e) { + throw IOException("Cannot call size() on IODevice", e); + } +} + +size_t IODevice::readAbsolute(StreamOffset readPosition, char* data, size_t len) { + StreamOffset storedPos = pos(); + seek(readPosition); + size_t ret = read(data, len); + seek(storedPos); + return ret; +} + +size_t IODevice::writeAbsolute(StreamOffset writePosition, char const* data, size_t len) { + StreamOffset storedPos = pos(); + seek(writePosition); + size_t ret = write(data, len); + seek(storedPos); + return ret; +} + +void IODevice::readFullAbsolute(StreamOffset readPosition, char* data, size_t len) { + while (len > 0) { + size_t r = readAbsolute(readPosition, data, len); + if (r == 0) + throw IOException("Failed to read full buffer in readFullAbsolute"); + readPosition += r; + data += r; + len -= r; + } +} + +void IODevice::writeFullAbsolute(StreamOffset writePosition, char const* data, size_t len) { + while (len > 0) { + size_t r = writeAbsolute(writePosition, data, len); + if (r == 0) + throw IOException("Failed to write full buffer in writeFullAbsolute"); + writePosition += r; + data += r; + len -= r; + } +} + +ByteArray IODevice::readBytes(size_t size) { + if (!size) + return {}; + + ByteArray p; + p.resize(size); + readFull(p.ptr(), size); + return p; +} + +void IODevice::writeBytes(ByteArray const& p) { + writeFull(p.ptr(), p.size()); +} + +ByteArray IODevice::readBytesAbsolute(StreamOffset readPosition, size_t size) { + if (!size) + return {}; + + ByteArray p; + p.resize(size); + readFullAbsolute(readPosition, p.ptr(), size); + return p; +} + +void IODevice::writeBytesAbsolute(StreamOffset writePosition, ByteArray const& p) { + writeFullAbsolute(writePosition, p.ptr(), p.size()); +} + +void IODevice::setMode(IOMode m) { + m_mode = m; +} + +IODevice::IODevice(IODevice const& rhs) { + m_mode = rhs.mode(); +} + +IODevice& IODevice::operator=(IODevice const& rhs) { + m_mode = rhs.mode(); + return *this; +} + +} diff --git a/source/core/StarIODevice.hpp b/source/core/StarIODevice.hpp new file mode 100644 index 0000000..834b1d4 --- /dev/null +++ b/source/core/StarIODevice.hpp @@ -0,0 +1,133 @@ +#ifndef STAR_IO_DEVICE_H +#define STAR_IO_DEVICE_H + +#include "StarByteArray.hpp" +#include "StarString.hpp" + +namespace Star { + +STAR_CLASS(IODevice); + +STAR_EXCEPTION(EofException, IOException); + +enum class IOMode : uint8_t { + Closed = 0x0, + Read = 0x1, + Write = 0x2, + ReadWrite = 0x3, + Append = 0x4, + Truncate = 0x8, +}; + +IOMode operator|(IOMode a, IOMode b); +bool operator&(IOMode a, IOMode b); + +// Should match SEEK_SET, SEEK_CUR, AND SEEK_END +enum IOSeek : uint8_t { + Absolute = 0, + Relative = 1, + End = 2 +}; + +// Abstract Interface to a random access I/O device. +class IODevice { +public: + IODevice(IOMode mode = IOMode::Closed); + virtual ~IODevice(); + + // Do a read or write that may result in less data read or written than + // requested. + virtual size_t read(char* data, size_t len) = 0; + virtual size_t write(char const* data, size_t len) = 0; + + virtual StreamOffset pos() = 0; + virtual void seek(StreamOffset pos, IOSeek mode = IOSeek::Absolute) = 0; + + // Default implementation throws unsupported exception. + virtual void resize(StreamOffset size); + + // Read / write from an absolute offset in the file without modifying the + // current file position. Default implementation stores the file position, + // then seeks and calls read/write partial, then restores the file position, + // and is not thread safe. + virtual size_t readAbsolute(StreamOffset readPosition, char* data, size_t len); + virtual size_t writeAbsolute(StreamOffset writePosition, char const* data, size_t len); + + // Read and write fully, and throw an exception in every other case. The + // default implementations here will call the normal read or write, and if + // the full amount is not read will throw an exception. + virtual void readFull(char* data, size_t len); + virtual void writeFull(char const* data, size_t len); + virtual void readFullAbsolute(StreamOffset readPosition, char* data, size_t len); + virtual void writeFullAbsolute(StreamOffset writePosition, char const* data, size_t len); + + // Default implementation throws exception if opening in a different mode + // than the current mode. + virtual void open(IOMode mode); + + // Default implementation sets mode equal to Closed + virtual void close(); + + // Default implementation is a no-op + virtual void sync(); + + // Default implementation just prints address of generic IODevice + virtual String deviceName() const; + + // Is the file position at the end of the file and there is no more to read? + // This is not the same as feof, which returns true after an unsuccesful read + // past the end, it should return true after succesfully reading the final + // byte. Default implementation returns pos() >= size(); + virtual bool atEnd(); + + // Default is to store position, seek end, then restore position. + virtual StreamOffset size(); + + IOMode mode() const; + bool isOpen() const; + bool isReadable() const; + bool isWritable() const; + + ByteArray readBytes(size_t size); + void writeBytes(ByteArray const& p); + + ByteArray readBytesAbsolute(StreamOffset readPosition, size_t size); + void writeBytesAbsolute(StreamOffset writePosition, ByteArray const& p); + +protected: + void setMode(IOMode mode); + + IODevice(IODevice const&); + IODevice& operator=(IODevice const&); + +private: + atomic<IOMode> m_mode; +}; + +inline IOMode operator|(IOMode a, IOMode b) { + return (IOMode)((uint8_t)a | (uint8_t)b); +} + +inline bool operator&(IOMode a, IOMode b) { + return (uint8_t)a & (uint8_t)b; +} + +inline IOMode IODevice::mode() const { + return m_mode; +} + +inline bool IODevice::isOpen() const { + return m_mode != IOMode::Closed; +} + +inline bool IODevice::isReadable() const { + return m_mode & IOMode::Read; +} + +inline bool IODevice::isWritable() const { + return m_mode & IOMode::Write; +} + +} + +#endif diff --git a/source/core/StarIdMap.hpp b/source/core/StarIdMap.hpp new file mode 100644 index 0000000..15946da --- /dev/null +++ b/source/core/StarIdMap.hpp @@ -0,0 +1,151 @@ +#ifndef STAR_ID_MAP_HPP +#define STAR_ID_MAP_HPP + +#include "StarMap.hpp" +#include "StarMathCommon.hpp" +#include "StarDataStream.hpp" + +namespace Star { + +STAR_EXCEPTION(IdMapException, StarException); + +// Maps key ids to values with auto generated ids in a given id range. Tries +// to cycle through ids as new values are added and avoid re-using ids until +// the id space wraps around. +template <typename BaseMap> +class IdMapWrapper : private BaseMap { +public: + typedef typename BaseMap::iterator iterator; + typedef typename BaseMap::const_iterator const_iterator; + typedef typename BaseMap::key_type key_type; + typedef typename BaseMap::value_type value_type; + typedef typename BaseMap::mapped_type mapped_type; + + typedef key_type IdType; + typedef value_type ValueType; + typedef mapped_type MappedType; + + IdMapWrapper(); + IdMapWrapper(IdType min, IdType max); + + // New valid id that does not exist in this map. Tries not to immediately + // recycle ids, to avoid temporally close id repeats. + IdType nextId(); + + // Throws exception if key already exists + void add(IdType id, MappedType mappedType); + + // Add with automatically allocated id + IdType add(MappedType mappedType); + + void clear(); + + bool operator==(IdMapWrapper const& rhs) const; + bool operator!=(IdMapWrapper const& rhs) const; + + using BaseMap::keys; + using BaseMap::values; + using BaseMap::pairs; + using BaseMap::contains; + using BaseMap::size; + using BaseMap::empty; + using BaseMap::get; + using BaseMap::ptr; + using BaseMap::maybe; + using BaseMap::take; + using BaseMap::maybeTake; + using BaseMap::remove; + using BaseMap::value; + using BaseMap::begin; + using BaseMap::end; + using BaseMap::erase; + + template <typename Base> + friend DataStream& operator>>(DataStream& ds, IdMapWrapper<Base>& map); + template <typename Base> + friend DataStream& operator<<(DataStream& ds, IdMapWrapper<Base> const& map); + +private: + IdType m_min; + IdType m_max; + IdType m_nextId; +}; + +template <class Key, class Value> +using IdMap = IdMapWrapper<Map<Key, Value>>; + +template <class Key, class Value> +using IdHashMap = IdMapWrapper<HashMap<Key, Value>>; + +template <typename BaseMap> +IdMapWrapper<BaseMap>::IdMapWrapper() + : m_min(lowest<IdType>()), m_max(highest<IdType>()), m_nextId(m_min) {} + +template <typename BaseMap> +IdMapWrapper<BaseMap>::IdMapWrapper(IdType min, IdType max) + : m_min(min), m_max(max), m_nextId(m_min) { + starAssert(m_max > m_min); +} + +template <typename BaseMap> +auto IdMapWrapper<BaseMap>::nextId() -> IdType { + if ((IdType)BaseMap::size() > m_max - m_min) + throw IdMapException("No id space left in IdMapWrapper"); + + IdType nextId = m_nextId; + while (BaseMap::contains(nextId)) + nextId = cycleIncrement(nextId, m_min, m_max); + m_nextId = cycleIncrement(nextId, m_min, m_max); + return nextId; +} + +template <typename BaseMap> +void IdMapWrapper<BaseMap>::add(IdType id, MappedType mappedType) { + if (!BaseMap::insert(make_pair(move(id), move(mappedType))).second) + throw IdMapException::format("IdMapWrapper::add(id, value) called with pre-existing id '%s'", outputAny(id)); +} + +template <typename BaseMap> +auto IdMapWrapper<BaseMap>::add(MappedType mappedType) -> IdType { + auto id = nextId(); + BaseMap::insert(id, mappedType); + return id; +} + +template <typename BaseMap> +void IdMapWrapper<BaseMap>::clear() { + BaseMap::clear(); + m_nextId = m_min; +} + +template <typename BaseMap> +bool IdMapWrapper<BaseMap>::operator==(IdMapWrapper const& rhs) const { + return tie(m_min, m_max) == tie(rhs.m_min, rhs.m_max) && BaseMap::operator==(rhs); +} + +template <typename BaseMap> +bool IdMapWrapper<BaseMap>::operator!=(IdMapWrapper const& rhs) const { + return !operator==(rhs); +} + +template <typename BaseMap> +DataStream& operator>>(DataStream& ds, IdMapWrapper<BaseMap>& map) { + ds.readMapContainer((BaseMap&)map); + ds.read(map.m_min); + ds.read(map.m_max); + ds.read(map.m_nextId); + return ds; +} + +template <typename BaseMap> +DataStream& operator<<(DataStream& ds, IdMapWrapper<BaseMap> const& map) { + ds.writeMapContainer((BaseMap const&)map); + ds.write(map.m_min); + ds.write(map.m_max); + ds.write(map.m_nextId); + return ds; +} + +} + +#endif diff --git a/source/core/StarImage.cpp b/source/core/StarImage.cpp new file mode 100644 index 0000000..bd3e838 --- /dev/null +++ b/source/core/StarImage.cpp @@ -0,0 +1,525 @@ +#include "StarImage.hpp" +#include "StarLogging.hpp" + +#include <png.h> + +namespace Star { + +Image Image::readPng(IODevicePtr device) { + auto logPngError = [](png_structp png_ptr, png_const_charp c) { + Logger::debug("PNG error in file: '%s', %s", (char*)png_get_error_ptr(png_ptr), c); + }; + + auto readPngData = [](png_structp pngPtr, png_bytep data, png_size_t length) { + IODevice* device = (IODevice*)png_get_io_ptr(pngPtr); + device->readFull((char*)data, length); + }; + + png_byte header[8]; + device->readFull((char*)header, sizeof(header)); + + if (png_sig_cmp(header, 0, sizeof(header))) + throw ImageException(strf("File %s is not a png image!", device->deviceName())); + + png_structp png_ptr = png_create_read_struct(PNG_LIBPNG_VER_STRING, nullptr, nullptr, nullptr); + if (!png_ptr) + throw ImageException("Internal libPNG error"); + + // Use custom warning function to suppress cerr warnings + png_set_error_fn(png_ptr, (png_voidp)device->deviceName().utf8Ptr(), logPngError, logPngError); + + png_infop info_ptr = png_create_info_struct(png_ptr); + if (!info_ptr) { + png_destroy_read_struct(&png_ptr, nullptr, nullptr); + throw ImageException("Internal libPNG error"); + } + + png_infop end_info = png_create_info_struct(png_ptr); + if (!end_info) { + png_destroy_read_struct(&png_ptr, &info_ptr, nullptr); + throw ImageException("Internal libPNG error"); + } + + if (setjmp(png_jmpbuf(png_ptr))) { + png_destroy_read_struct(&png_ptr, &info_ptr, &end_info); + throw ImageException("Internal error reading png."); + } + + png_set_read_fn(png_ptr, device.get(), readPngData); + + // Tell libPNG that we read some of the header. + png_set_sig_bytes(png_ptr, sizeof(header)); + + png_read_info(png_ptr, info_ptr); + + png_uint_32 img_width = png_get_image_width(png_ptr, info_ptr); + png_uint_32 img_height = png_get_image_height(png_ptr, info_ptr); + + png_uint_32 bitdepth = png_get_bit_depth(png_ptr, info_ptr); + png_uint_32 channels = png_get_channels(png_ptr, info_ptr); + + // Color type. (RGB, RGBA, Luminance, luminance alpha... palette... etc) + png_uint_32 color_type = png_get_color_type(png_ptr, info_ptr); + + if (color_type == PNG_COLOR_TYPE_PALETTE) { + png_set_palette_to_rgb(png_ptr); + channels = 3; + bitdepth = 8; + } + + if (color_type == PNG_COLOR_TYPE_GRAY || color_type == PNG_COLOR_TYPE_GRAY_ALPHA) { + if (bitdepth < 8) { + png_set_expand_gray_1_2_4_to_8(png_ptr); + bitdepth = 8; + } + png_set_gray_to_rgb(png_ptr); + if (color_type == PNG_COLOR_TYPE_GRAY_ALPHA) + channels = 4; + else + channels = 3; + } + + // If the image has a transperancy set, convert it to a full alpha channel + if (png_get_valid(png_ptr, info_ptr, PNG_INFO_tRNS)) { + png_set_tRNS_to_alpha(png_ptr); + channels += 1; + } + + // We don't support 16 bit precision.. so if the image Has 16 bits per channel + // precision... round it down to 8. + if (bitdepth == 16) { + png_set_strip_16(png_ptr); + bitdepth = 8; + } + + if (bitdepth != 8 || (channels != 3 && channels != 4)) { + png_destroy_read_struct(&png_ptr, &info_ptr, &end_info); + throw ImageException(strf("Unsupported PNG pixel format in file %s", device->deviceName())); + } + + Image image(img_width, img_height, channels == 3 ? PixelFormat::RGB24 : PixelFormat::RGBA32); + + std::unique_ptr<png_bytep[]> row_ptrs(new png_bytep[img_height]); + size_t stride = img_width * channels; + for (size_t i = 0; i < img_height; ++i) + row_ptrs[i] = (png_bytep)image.data() + (img_height - i - 1) * stride; + + png_read_image(png_ptr, row_ptrs.get()); + png_destroy_read_struct(&png_ptr, &info_ptr, &end_info); + + return image; +} + +tuple<Vec2U, PixelFormat> Image::readPngMetadata(IODevicePtr device) { + auto logPngError = [](png_structp png_ptr, png_const_charp c) { + Logger::debug("PNG error in file: '%s', %s", (char*)png_get_error_ptr(png_ptr), c); + }; + + auto readPngData = [](png_structp pngPtr, png_bytep data, png_size_t length) { + IODevice* device = (IODevice*)png_get_io_ptr(pngPtr); + device->readFull((char*)data, length); + }; + + png_byte header[8]; + device->readFull((char*)header, sizeof(header)); + + if (png_sig_cmp(header, 0, sizeof(header))) + throw ImageException(strf("File %s is not a png image!", device->deviceName())); + + png_structp png_ptr = png_create_read_struct(PNG_LIBPNG_VER_STRING, nullptr, nullptr, nullptr); + if (!png_ptr) + throw ImageException("Internal libPNG error"); + + // Use custom warning function to suppress cerr warnings + png_set_error_fn(png_ptr, (png_voidp)device->deviceName().utf8Ptr(), logPngError, logPngError); + + png_infop info_ptr = png_create_info_struct(png_ptr); + if (!info_ptr) { + png_destroy_read_struct(&png_ptr, nullptr, nullptr); + throw ImageException("Internal libPNG error"); + } + + png_infop end_info = png_create_info_struct(png_ptr); + if (!end_info) { + png_destroy_read_struct(&png_ptr, &info_ptr, nullptr); + throw ImageException("Internal libPNG error"); + } + + if (setjmp(png_jmpbuf(png_ptr))) { + png_destroy_read_struct(&png_ptr, &info_ptr, &end_info); + throw ImageException("Internal error reading png."); + } + + png_set_read_fn(png_ptr, device.get(), readPngData); + + // Tell libPNG that we read some of the header. + png_set_sig_bytes(png_ptr, sizeof(header)); + + png_read_info(png_ptr, info_ptr); + + png_uint_32 img_width = png_get_image_width(png_ptr, info_ptr); + png_uint_32 img_height = png_get_image_height(png_ptr, info_ptr); + + png_uint_32 bitdepth = png_get_bit_depth(png_ptr, info_ptr); + png_uint_32 channels = png_get_channels(png_ptr, info_ptr); + + // Color type. (RGB, RGBA, Luminance, luminance alpha... palette... etc) + png_uint_32 color_type = png_get_color_type(png_ptr, info_ptr); + + if (color_type == PNG_COLOR_TYPE_PALETTE) { + png_set_palette_to_rgb(png_ptr); + channels = 3; + bitdepth = 8; + } + + if (color_type == PNG_COLOR_TYPE_GRAY || color_type == PNG_COLOR_TYPE_GRAY_ALPHA) { + if (bitdepth < 8) { + png_set_expand_gray_1_2_4_to_8(png_ptr); + bitdepth = 8; + } + png_set_gray_to_rgb(png_ptr); + if (color_type == PNG_COLOR_TYPE_GRAY_ALPHA) + channels = 4; + else + channels = 3; + } + + // If the image has a transperancy set, convert it to a full alpha channel + if (png_get_valid(png_ptr, info_ptr, PNG_INFO_tRNS)) { + png_set_tRNS_to_alpha(png_ptr); + channels += 1; + } + + Vec2U imageSize{img_width, img_height}; + PixelFormat pixelFormat = channels == 3 ? PixelFormat::RGB24 : PixelFormat::RGBA32; + + return make_tuple(imageSize, pixelFormat); +} + +Image Image::filled(Vec2U size, Vec4B color, PixelFormat pf) { + Image image(size, pf); + image.fill(color); + return image; +} + +Image::Image(PixelFormat pf) + : m_data(nullptr), m_width(0), m_height(0), m_pixelFormat(pf) {} + +Image::Image(Vec2U size, PixelFormat pf) + : Image(size[0], size[1], pf) {} + +Image::Image(unsigned width, unsigned height, PixelFormat pf) + : Image(pf) { + reset(width, height, pf); +} + +Image::~Image() { + if (m_data) + Star::free(m_data); +} + +Image::Image(Image const& image) : Image() { + operator=(image); +} + +Image::Image(Image&& image) : Image() { + operator=(move(image)); +} + +Image& Image::operator=(Image const& image) { + reset(image.m_width, image.m_height, image.m_pixelFormat); + memcpy(data(), image.data(), m_width * m_height * bytesPerPixel()); + return *this; +} + +Image& Image::operator=(Image&& image) { + reset(0, 0, m_pixelFormat); + + m_data = take(image.m_data); + m_width = image.m_width; + m_height = image.m_height; + m_pixelFormat = image.m_pixelFormat; + return *this; +} + +void Image::reset(Vec2U size, Maybe<PixelFormat> pf) { + reset(size[0], size[1], pf); +} + +void Image::reset(unsigned width, unsigned height, Maybe<PixelFormat> pf) { + if (!pf) + pf = m_pixelFormat; + + if (m_width == width && m_height == height && m_pixelFormat == *pf) + return; + + size_t imageSize = width * height * Star::bytesPerPixel(*pf); + if (imageSize == 0) { + if (m_data) { + Star::free(m_data); + m_data = nullptr; + } + } else { + uint8_t* newData = nullptr; + if (!m_data) + newData = (uint8_t*)Star::malloc(imageSize); + else + newData = (uint8_t*)Star::realloc(m_data, imageSize); + + if (!newData) + throw MemoryException::format("Could not allocate memory for new Image size %s\n", imageSize); + + m_data = newData; + memset(m_data, 0, imageSize); + } + + m_pixelFormat = *pf; + m_width = width; + m_height = height; +} + +void Image::fill(Vec3B const& c) { + if (bitsPerPixel() == 24) { + for (unsigned y = 0; y < m_height; ++y) + for (unsigned x = 0; x < m_width; ++x) + set24(x, y, c); + } else { + for (unsigned y = 0; y < m_height; ++y) + for (unsigned x = 0; x < m_width; ++x) + set32(x, y, Vec4B(c, 255)); + } +} + +void Image::fill(Vec4B const& c) { + if (bitsPerPixel() == 24) { + for (unsigned y = 0; y < m_height; ++y) + for (unsigned x = 0; x < m_width; ++x) + set24(x, y, c.vec3()); + } else { + for (unsigned y = 0; y < m_height; ++y) + for (unsigned x = 0; x < m_width; ++x) + set32(x, y, c); + } +} + +void Image::fillRect(Vec2U const& pos, Vec2U const& size, Vec3B const& c) { + for (unsigned y = pos[1]; y < pos[1] + size[1] && y < m_height; ++y) + for (unsigned x = pos[0]; x < pos[0] + size[0] && x < m_width; ++x) + set(Vec2U(x, y), c); +} + +void Image::fillRect(Vec2U const& pos, Vec2U const& size, Vec4B const& c) { + for (unsigned y = pos[1]; y < pos[1] + size[1] && y < m_height; ++y) + for (unsigned x = pos[0]; x < pos[0] + size[0] && x < m_width; ++x) + set(Vec2U(x, y), c); +} + +void Image::set(Vec2U const& pos, Vec4B const& c) { + if (pos[0] >= m_width || pos[1] >= m_height) { + throw ImageException(strf("%s out of range in Image::set", pos)); + } else if (bytesPerPixel() == 4) { + size_t offset = pos[1] * m_width * 4 + pos[0] * 4; + m_data[offset] = c[0]; + m_data[offset + 1] = c[1]; + m_data[offset + 2] = c[2]; + m_data[offset + 3] = c[3]; + } else if (bytesPerPixel() == 3) { + size_t offset = pos[1] * m_width * 3 + pos[0] * 3; + m_data[offset] = c[0]; + m_data[offset + 1] = c[1]; + m_data[offset + 2] = c[2]; + } +} + +void Image::set(Vec2U const& pos, Vec3B const& c) { + if (pos[0] >= m_width || pos[1] >= m_height) { + throw ImageException(strf("%s out of range in Image::set", pos)); + } else if (bytesPerPixel() == 4) { + size_t offset = pos[1] * m_width * 4 + pos[0] * 4; + m_data[offset] = c[0]; + m_data[offset + 1] = c[1]; + m_data[offset + 2] = c[2]; + m_data[offset + 3] = 255; + } else if (bytesPerPixel() == 3) { + size_t offset = pos[1] * m_width * 3 + pos[0] * 3; + m_data[offset] = c[0]; + m_data[offset + 1] = c[1]; + m_data[offset + 2] = c[2]; + } +} + +Vec4B Image::get(Vec2U const& pos) const { + Vec4B c; + if (pos[0] >= m_width || pos[1] >= m_height) { + throw ImageException(strf("%s out of range in Image::get", pos)); + } else if (bytesPerPixel() == 4) { + size_t offset = pos[1] * m_width * 4 + pos[0] * 4; + c[0] = m_data[offset]; + c[1] = m_data[offset + 1]; + c[2] = m_data[offset + 2]; + c[3] = m_data[offset + 3]; + } else if (bytesPerPixel() == 3) { + size_t offset = pos[1] * m_width * 3 + pos[0] * 3; + c[0] = m_data[offset]; + c[1] = m_data[offset + 1]; + c[2] = m_data[offset + 2]; + c[3] = 255; + } + return c; +} + +void Image::setrgb(Vec2U const& pos, Vec4B const& c) { + if (m_pixelFormat == PixelFormat::BGR24 || m_pixelFormat == PixelFormat::BGRA32) + set(pos, Vec4B{c[2], c[1], c[0], c[3]}); + else + set(pos, c); +} + +void Image::setrgb(Vec2U const& pos, Vec3B const& c) { + if (m_pixelFormat == PixelFormat::BGR24 || m_pixelFormat == PixelFormat::BGRA32) + set(pos, Vec3B{c[2], c[1], c[0]}); + else + set(pos, c); +} + +Vec4B Image::getrgb(Vec2U const& pos) const { + auto c = get(pos); + if (m_pixelFormat == PixelFormat::BGR24 || m_pixelFormat == PixelFormat::BGRA32) + return Vec4B{c[2], c[1], c[0], c[3]}; + else + return c; +} + +Vec4B Image::clamp(Vec2I const& pos) const { + Vec4B c; + unsigned x = (unsigned)Star::clamp<int>(pos[0], 0, m_width - 1); + unsigned y = (unsigned)Star::clamp<int>(pos[1], 0, m_height - 1); + if (m_width == 0 || m_height == 0) { + return {0, 0, 0, 0}; + } else if (bytesPerPixel() == 4) { + size_t offset = y * m_width * 4 + x * 4; + c[0] = m_data[offset]; + c[1] = m_data[offset + 1]; + c[2] = m_data[offset + 2]; + c[3] = m_data[offset + 3]; + } else if (bytesPerPixel() == 3) { + size_t offset = y * m_width * 3 + x * 3; + c[0] = m_data[offset]; + c[1] = m_data[offset + 1]; + c[2] = m_data[offset + 2]; + c[3] = 255; + } + return c; +} + +Vec4B Image::clamprgb(Vec2I const& pos) const { + auto c = clamp(pos); + if (m_pixelFormat == PixelFormat::BGR24 || m_pixelFormat == PixelFormat::BGRA32) + return Vec4B{c[2], c[1], c[0], c[3]}; + else + return c; +} + +Image Image::subImage(Vec2U const& pos, Vec2U const& size) const { + if (pos[0] + size[0] > m_width || pos[1] + size[1] > m_height) + throw ImageException(strf("call to subImage with pos %s size %s out of image bounds (%s, %s)", pos, size, m_width, m_height)); + + Image sub(size[0], size[1], m_pixelFormat); + + for (unsigned y = 0; y < size[1]; ++y) { + for (unsigned x = 0; x < size[0]; ++x) { + sub.set({x, y}, get(pos + Vec2U(x, y))); + } + } + + return sub; +} + +void Image::copyInto(Vec2U const& min, Image const& image) { + Vec2U max = (min + image.size()).piecewiseMin(size()); + + for (unsigned y = min[1]; y < max[1]; ++y) { + for (unsigned x = min[0]; x < max[0]; ++x) + set(x, y, image.get(Vec2U(x, y) - min)); + } +} + +void Image::drawInto(Vec2U const& min, Image const& image) { + Vec2U max = (min + image.size()).piecewiseMin(size()); + + for (unsigned y = min[1]; y < max[1]; ++y) { + for (unsigned x = min[0]; x < max[0]; ++x) { + Vec4B dest = get(Vec2U(x, y)); + Vec4B src = image.get(Vec2U(x, y) - min); + + Vec3U destMultiplied = Vec3U(dest[0], dest[1], dest[2]) * dest[3] / 255; + Vec3U srcMultiplied = Vec3U(src[0], src[1], src[2]) * src[3] / 255; + + // Src over dest alpha composition + Vec3U over = srcMultiplied + destMultiplied * (255 - src[3]) / 255; + unsigned alpha = src[3] + dest[3] * (255 - src[3]) / 255; + + set(x, y, Vec4B(over[0], over[1], over[2], alpha)); + } + } +} + +Image Image::convert(PixelFormat pixelFormat) const { + Image converted(m_width, m_height, pixelFormat); + converted.copyInto(Vec2U(), *this); + return converted; +} + +void Image::writePng(IODevicePtr device) const { + auto writePngData = [](png_structp pngPtr, png_bytep data, png_size_t length) { + IODevice* device = (IODevice*)png_get_io_ptr(pngPtr); + device->writeFull((char*)data, length); + }; + + auto flushPngData = [](png_structp) {}; + + png_structp png_ptr = nullptr; + png_infop info_ptr = nullptr; + + png_ptr = png_create_write_struct(PNG_LIBPNG_VER_STRING, nullptr, nullptr, nullptr); + if (!png_ptr) + throw ImageException("Internal libPNG error"); + + info_ptr = png_create_info_struct(png_ptr); + if (!info_ptr) { + png_destroy_write_struct(&png_ptr, nullptr); + throw ImageException("Internal libPNG error"); + } + + if (setjmp(png_jmpbuf(png_ptr))) { + png_destroy_write_struct(&png_ptr, &info_ptr); + throw ImageException("Internal error reading png."); + } + + unsigned channels = m_pixelFormat == PixelFormat::RGB24 ? 3 : 4; + + png_set_IHDR(png_ptr, + info_ptr, + m_width, + m_height, + 8, + channels == 3 ? PNG_COLOR_TYPE_RGB : PNG_COLOR_TYPE_RGBA, + PNG_INTERLACE_NONE, + PNG_COMPRESSION_TYPE_DEFAULT, + PNG_FILTER_TYPE_DEFAULT); + + unique_ptr<png_bytep[]> row_ptrs(new png_bytep[m_height]); + size_t stride = m_width * 8 * channels / 8; + for (size_t i = 0; i < m_height; ++i) { + size_t q = (m_height - i - 1) * stride; + row_ptrs[i] = (png_bytep)m_data + q; + } + + png_set_write_fn(png_ptr, device.get(), writePngData, flushPngData); + png_set_rows(png_ptr, info_ptr, row_ptrs.get()); + png_write_png(png_ptr, info_ptr, PNG_TRANSFORM_IDENTITY, nullptr); + + png_destroy_write_struct(&png_ptr, &info_ptr); +} + +} diff --git a/source/core/StarImage.hpp b/source/core/StarImage.hpp new file mode 100644 index 0000000..cce0d2f --- /dev/null +++ b/source/core/StarImage.hpp @@ -0,0 +1,313 @@ +#ifndef STAR_IMAGE_HPP +#define STAR_IMAGE_HPP + +#include "StarString.hpp" +#include "StarVector.hpp" +#include "StarIODevice.hpp" + +namespace Star { + +enum class PixelFormat { + RGB24, + RGBA32, + BGR24, + BGRA32 +}; + +uint8_t bitsPerPixel(PixelFormat pf); +uint8_t bytesPerPixel(PixelFormat pf); + +STAR_EXCEPTION(ImageException, StarException); + +STAR_CLASS(Image); + +// Holds an image of PixelFormat in row major order, with no padding, with (0, +// 0) defined to be the *lower left* corner. +class Image { +public: + static Image readPng(IODevicePtr device); + // Returns the size and pixel format that would be constructed from the given + // png file, without actually loading it. + static tuple<Vec2U, PixelFormat> readPngMetadata(IODevicePtr device); + + static Image filled(Vec2U size, Vec4B color, PixelFormat pf = PixelFormat::RGBA32); + + // Creates a zero size image + Image(PixelFormat pf = PixelFormat::RGBA32); + Image(Vec2U size, PixelFormat pf = PixelFormat::RGBA32); + Image(unsigned width, unsigned height, PixelFormat pf = PixelFormat::RGBA32); + ~Image(); + + Image(Image const& image); + Image(Image&& image); + + Image& operator=(Image const& image); + Image& operator=(Image&& image); + + uint8_t bitsPerPixel() const; + uint8_t bytesPerPixel() const; + + unsigned width() const; + unsigned height() const; + Vec2U size() const; + // width or height is 0 + bool empty() const; + + PixelFormat pixelFormat() const; + + // If the image is empty, the data ptr will be null + uint8_t const* data() const; + uint8_t* data(); + + // Reallocate the image with the given width, height, and pixel format. The + // contents of the image are always zeroed after a call to reset. + void reset(Vec2U size, Maybe<PixelFormat> pf = {}); + void reset(unsigned width, unsigned height, Maybe<PixelFormat> pf = {}); + + // Fill the image with a given color + void fill(Vec3B const& c); + void fill(Vec4B const& c); + + // Fill a rectangle with a given color + void fillRect(Vec2U const& pos, Vec2U const& size, Vec3B const& c); + void fillRect(Vec2U const& pos, Vec2U const& size, Vec4B const& c); + + // Color parameters / return values here are in whatever the internal format + // is. Fourth byte, if missing or not provided, is assumed to be 255. If + // the position is out of range, then throws an exception. + void set(Vec2U const& pos, Vec4B const& c); + void set(Vec2U const& pos, Vec3B const& c); + Vec4B get(Vec2U const& pos) const; + + // Same as set / get, except color parameters / return values here are always + // RGB[A], and converts if necessary. + void setrgb(Vec2U const& pos, Vec4B const& c); + void setrgb(Vec2U const& pos, Vec3B const& c); + Vec4B getrgb(Vec2U const& pos) const; + + // Get pixel value, but if pos is out of the normal pixel range, it is + // clamped back into the valid pixel range. Returns (0, 0, 0, 0) if image is + // empty. + Vec4B clamp(Vec2I const& pos) const; + Vec4B clamprgb(Vec2I const& pos) const; + + // x / y versions of set / get, for compatibility + void set(unsigned x, unsigned y, Vec4B const& c); + void set(unsigned x, unsigned y, Vec3B const& c); + Vec4B get(unsigned x, unsigned y) const; + void setrgb(unsigned x, unsigned y, Vec4B const& c); + void setrgb(unsigned x, unsigned y, Vec3B const& c); + Vec4B getrgb(unsigned x, unsigned y) const; + Vec4B clamp(int x, int y) const; + Vec4B clamprgb(int x, int y) const; + + // Must be 32 bitsPerPixel, no format conversion or bounds checking takes + // place. Very fast inline versions. + void set32(Vec2U const& pos, Vec4B const& c); + void set32(unsigned x, unsigned y, Vec4B const& c); + Vec4B get32(unsigned x, unsigned y) const; + + // Must be 24 bitsPerPixel, no format conversion or bounds checking takes + // place. Very fast inline versions. + void set24(Vec2U const& pos, Vec3B const& c); + void set24(unsigned x, unsigned y, Vec3B const& c); + Vec3B get24(unsigned x, unsigned y) const; + + // Called as callback(unsigned x, unsigned y, Vec4B const& pixel) + template <typename CallbackType> + void forEachPixel(CallbackType&& callback) const; + + // Called as callback(unsigned x, unsigned y, Vec4B& pixel) + template <typename CallbackType> + void forEachPixel(CallbackType&& callback); + + // Pixel rectangle, lower left position and size of rectangle. + Image subImage(Vec2U const& pos, Vec2U const& size) const; + + // Copy given image into this one at pos + void copyInto(Vec2U const& pos, Image const& image); + // Draw given image over this one at pos (with alpha composition) + void drawInto(Vec2U const& pos, Image const& image); + + // Convert this image into the given pixel format + Image convert(PixelFormat pixelFormat) const; + + void writePng(IODevicePtr device) const; + +private: + uint8_t* m_data; + unsigned m_width; + unsigned m_height; + PixelFormat m_pixelFormat; +}; + +inline uint8_t bitsPerPixel(PixelFormat pf) { + switch (pf) { + case PixelFormat::RGB24: + return 24; + case PixelFormat::RGBA32: + return 32; + case PixelFormat::BGR24: + return 24; + default: + return 32; + } +} + +inline uint8_t bytesPerPixel(PixelFormat pf) { + switch (pf) { + case PixelFormat::RGB24: + return 3; + case PixelFormat::RGBA32: + return 4; + case PixelFormat::BGR24: + return 3; + default: + return 4; + } +} + +inline uint8_t Image::bitsPerPixel() const { + return Star::bitsPerPixel(m_pixelFormat); +} + +inline uint8_t Image::bytesPerPixel() const { + return Star::bytesPerPixel(m_pixelFormat); +} + +inline unsigned Image::width() const { + return m_width; +} + +inline unsigned Image::height() const { + return m_height; +} + +inline bool Image::empty() const { + return m_width == 0 || m_height == 0; +} + +inline Vec2U Image::size() const { + return {m_width, m_height}; +} + +inline PixelFormat Image::pixelFormat() const { + return m_pixelFormat; +} + +inline const uint8_t* Image::data() const { + return m_data; +} + +inline uint8_t* Image::data() { + return m_data; +} + +inline void Image::set(unsigned x, unsigned y, Vec4B const& c) { + return set({x, y}, c); +} + +inline void Image::set(unsigned x, unsigned y, Vec3B const& c) { + return set({x, y}, c); +} + +inline Vec4B Image::get(unsigned x, unsigned y) const { + return get({x, y}); +} + +inline void Image::setrgb(unsigned x, unsigned y, Vec4B const& c) { + return setrgb({x, y}, c); +} + +inline void Image::setrgb(unsigned x, unsigned y, Vec3B const& c) { + return setrgb({x, y}, c); +} + +inline Vec4B Image::getrgb(unsigned x, unsigned y) const { + return getrgb({x, y}); +} + +inline Vec4B Image::clamp(int x, int y) const { + return clamp({x, y}); +} + +inline Vec4B Image::clamprgb(int x, int y) const { + return clamprgb({x, y}); +} + +inline void Image::set32(Vec2U const& pos, Vec4B const& c) { + set32(pos[0], pos[1], c); +} + +inline void Image::set32(unsigned x, unsigned y, Vec4B const& c) { + starAssert(m_data && x < m_width && y < m_height); + starAssert(bytesPerPixel() == 4); + + size_t offset = y * m_width * 4 + x * 4; + m_data[offset] = c[0]; + m_data[offset + 1] = c[1]; + m_data[offset + 2] = c[2]; + m_data[offset + 3] = c[3]; +} + +inline Vec4B Image::get32(unsigned x, unsigned y) const { + starAssert(m_data && x < m_width && y < m_height); + starAssert(bytesPerPixel() == 4); + + Vec4B c; + size_t offset = y * m_width * 4 + x * 4; + c[0] = m_data[offset]; + c[1] = m_data[offset + 1]; + c[2] = m_data[offset + 2]; + c[3] = m_data[offset + 3]; + return c; +} + +inline void Image::set24(Vec2U const& pos, Vec3B const& c) { + set24(pos[0], pos[1], c); +} + +inline void Image::set24(unsigned x, unsigned y, Vec3B const& c) { + starAssert(m_data && x < m_width && y < m_height); + starAssert(bytesPerPixel() == 3); + + size_t offset = y * m_width * 3 + x * 3; + m_data[offset] = c[0]; + m_data[offset + 1] = c[1]; + m_data[offset + 2] = c[2]; +} + +inline Vec3B Image::get24(unsigned x, unsigned y) const { + starAssert(m_data && x < m_width && y < m_height); + starAssert(bytesPerPixel() == 3); + + Vec3B c; + size_t offset = y * m_width * 3 + x * 3; + c[0] = m_data[offset]; + c[1] = m_data[offset + 1]; + c[2] = m_data[offset + 2]; + return c; +} + +template <typename CallbackType> +void Image::forEachPixel(CallbackType&& callback) const { + for (unsigned y = 0; y < m_height; y++) { + for (unsigned x = 0; x < m_width; x++) + callback(x, y, get(x, y)); + } +} + +template <typename CallbackType> +void Image::forEachPixel(CallbackType&& callback) { + for (unsigned y = 0; y < m_height; y++) { + for (unsigned x = 0; x < m_width; x++) { + Vec4B pixel = get(x, y); + callback(x, y, pixel); + set(x, y, pixel); + } + } +} + +} + +#endif diff --git a/source/core/StarImageProcessing.cpp b/source/core/StarImageProcessing.cpp new file mode 100644 index 0000000..3e1db89 --- /dev/null +++ b/source/core/StarImageProcessing.cpp @@ -0,0 +1,550 @@ +#include "StarImageProcessing.hpp" +#include "StarMatrix3.hpp" +#include "StarInterpolation.hpp" +#include "StarLexicalCast.hpp" +#include "StarColor.hpp" +#include "StarImage.hpp" + +namespace Star { + +Image scaleNearest(Image const& srcImage, Vec2F const& scale) { + Vec2U srcSize = srcImage.size(); + Vec2U destSize = Vec2U::round(vmult(Vec2F(srcSize), scale)); + destSize[0] = max(destSize[0], 1u); + destSize[1] = max(destSize[1], 1u); + + Image destImage(destSize, srcImage.pixelFormat()); + + for (unsigned y = 0; y < destSize[1]; ++y) { + for (unsigned x = 0; x < destSize[0]; ++x) + destImage.set({x, y}, srcImage.clamp(Vec2I::round(vdiv(Vec2F(x, y), scale)))); + } + return destImage; +} + +Image scaleBilinear(Image const& srcImage, Vec2F const& scale) { + Vec2U srcSize = srcImage.size(); + Vec2U destSize = Vec2U::round(vmult(Vec2F(srcSize), scale)); + destSize[0] = max(destSize[0], 1u); + destSize[1] = max(destSize[1], 1u); + + Image destImage(destSize, srcImage.pixelFormat()); + + for (unsigned y = 0; y < destSize[1]; ++y) { + for (unsigned x = 0; x < destSize[0]; ++x) { + auto pos = vdiv(Vec2F(x, y), scale); + auto ipart = Vec2I::floor(pos); + auto fpart = pos - Vec2F(ipart); + + auto result = lerp(fpart[1], lerp(fpart[0], Vec4F(srcImage.clamp(ipart[0], ipart[1])), Vec4F(srcImage.clamp(ipart[0] + 1, ipart[1]))), lerp(fpart[0], + Vec4F(srcImage.clamp(ipart[0], ipart[1] + 1)), Vec4F(srcImage.clamp(ipart[0] + 1, ipart[1] + 1)))); + + destImage.set({x, y}, Vec4B(result)); + } + } + + return destImage; +} + +Image scaleBicubic(Image const& srcImage, Vec2F const& scale) { + Vec2U srcSize = srcImage.size(); + Vec2U destSize = Vec2U::round(vmult(Vec2F(srcSize), scale)); + destSize[0] = max(destSize[0], 1u); + destSize[1] = max(destSize[1], 1u); + + Image destImage(destSize, srcImage.pixelFormat()); + + for (unsigned y = 0; y < destSize[1]; ++y) { + for (unsigned x = 0; x < destSize[0]; ++x) { + auto pos = vdiv(Vec2F(x, y), scale); + auto ipart = Vec2I::floor(pos); + auto fpart = pos - Vec2F(ipart); + + Vec4F a = cubic4(fpart[0], + Vec4F(srcImage.clamp(ipart[0], ipart[1])), + Vec4F(srcImage.clamp(ipart[0] + 1, ipart[1])), + Vec4F(srcImage.clamp(ipart[0] + 2, ipart[1])), + Vec4F(srcImage.clamp(ipart[0] + 3, ipart[1]))); + + Vec4F b = cubic4(fpart[0], + Vec4F(srcImage.clamp(ipart[0], ipart[1] + 1)), + Vec4F(srcImage.clamp(ipart[0] + 1, ipart[1] + 1)), + Vec4F(srcImage.clamp(ipart[0] + 2, ipart[1] + 1)), + Vec4F(srcImage.clamp(ipart[0] + 3, ipart[1] + 1))); + + Vec4F c = cubic4(fpart[0], + Vec4F(srcImage.clamp(ipart[0], ipart[1] + 2)), + Vec4F(srcImage.clamp(ipart[0] + 1, ipart[1] + 2)), + Vec4F(srcImage.clamp(ipart[0] + 2, ipart[1] + 2)), + Vec4F(srcImage.clamp(ipart[0] + 3, ipart[1] + 2))); + + Vec4F d = cubic4(fpart[0], + Vec4F(srcImage.clamp(ipart[0], ipart[1] + 3)), + Vec4F(srcImage.clamp(ipart[0] + 1, ipart[1] + 3)), + Vec4F(srcImage.clamp(ipart[0] + 2, ipart[1] + 3)), + Vec4F(srcImage.clamp(ipart[0] + 3, ipart[1] + 3))); + + auto result = cubic4(fpart[1], a, b, c, d); + + destImage.set({x, y}, Vec4B( + clamp(result[0], 0.0f, 255.0f), + clamp(result[1], 0.0f, 255.0f), + clamp(result[2], 0.0f, 255.0f), + clamp(result[3], 0.0f, 255.0f) + )); + } + } + + return destImage; +} + +StringList colorDirectivesFromConfig(JsonArray const& directives) { + List<String> result; + + for (auto entry : directives) { + if (entry.type() == Json::Type::String) { + result.append(entry.toString()); + } else if (entry.type() == Json::Type::Object) { + result.append(paletteSwapDirectivesFromConfig(entry)); + } else { + throw StarException("Malformed color directives list."); + } + } + return result; +} + +String paletteSwapDirectivesFromConfig(Json const& swaps) { + ColorReplaceImageOperation paletteSwaps; + for (auto const& swap : swaps.iterateObject()) + paletteSwaps.colorReplaceMap[Color::fromHex(swap.first).toRgba()] = Color::fromHex(swap.second.toString()).toRgba(); + return "?" + imageOperationToString(paletteSwaps); +} + +HueShiftImageOperation HueShiftImageOperation::hueShiftDegrees(float degrees) { + return HueShiftImageOperation{degrees / 360.0f}; +} + +SaturationShiftImageOperation SaturationShiftImageOperation::saturationShift100(float amount) { + return SaturationShiftImageOperation{amount / 100.0f}; +} + +BrightnessMultiplyImageOperation BrightnessMultiplyImageOperation::brightnessMultiply100(float amount) { + return BrightnessMultiplyImageOperation{amount / 100.0f + 1.0f}; +} + +FadeToColorImageOperation::FadeToColorImageOperation(Vec3B color, float amount) { + this->color = color; + this->amount = amount; + + auto fcl = Color::rgb(color).toLinear(); + for (int i = 0; i <= 255; ++i) { + auto r = Color::rgb(Vec3B(i, i, i)).toLinear().mix(fcl, amount).toSRGB().toRgb(); + rTable[i] = r[0]; + gTable[i] = r[1]; + bTable[i] = r[2]; + } +} + +ImageOperation imageOperationFromString(String const& string) { + try { + auto bits = string.splitAny("=;"); + String type = bits.at(0); + + if (type == "hueshift") { + return HueShiftImageOperation::hueShiftDegrees(lexicalCast<float>(bits.at(1))); + + } else if (type == "saturation") { + return SaturationShiftImageOperation::saturationShift100(lexicalCast<float>(bits.at(1))); + + } else if (type == "brightness") { + return BrightnessMultiplyImageOperation::brightnessMultiply100(lexicalCast<float>(bits.at(1))); + + } else if (type == "fade") { + return FadeToColorImageOperation(Color::fromHex(bits.at(1)).toRgb(), lexicalCast<float>(bits.at(2))); + + } else if (type == "scanlines") { + return ScanLinesImageOperation{ + FadeToColorImageOperation(Color::fromHex(bits.at(1)).toRgb(), lexicalCast<float>(bits.at(2))), + FadeToColorImageOperation(Color::fromHex(bits.at(3)).toRgb(), lexicalCast<float>(bits.at(4)))}; + + } else if (type == "setcolor") { + return SetColorImageOperation{Color::fromHex(bits.at(1)).toRgb()}; + + } else if (type == "replace") { + ColorReplaceImageOperation operation; + for (size_t i = 0; i < (bits.size() - 1) / 2; ++i) + operation.colorReplaceMap[Color::fromHex(bits[i * 2 + 1]).toRgba()] = Color::fromHex(bits[i * 2 + 2]).toRgba(); + + return operation; + + } else if (type == "addmask" || type == "submask") { + AlphaMaskImageOperation operation; + if (type == "addmask") + operation.mode = AlphaMaskImageOperation::Additive; + else + operation.mode = AlphaMaskImageOperation::Subtractive; + + operation.maskImages = bits.at(1).split('+'); + + if (bits.size() > 2) + operation.offset[0] = lexicalCast<int>(bits.at(2)); + + if (bits.size() > 3) + operation.offset[1] = lexicalCast<int>(bits.at(3)); + + return operation; + + } else if (type == "blendmult" || type == "blendscreen") { + BlendImageOperation operation; + + if (type == "blendmult") + operation.mode = BlendImageOperation::Multiply; + else + operation.mode = BlendImageOperation::Screen; + + operation.blendImages = bits.at(1).split('+'); + + if (bits.size() > 2) + operation.offset[0] = lexicalCast<int>(bits.at(2)); + + if (bits.size() > 3) + operation.offset[1] = lexicalCast<int>(bits.at(3)); + + return operation; + + } else if (type == "multiply") { + return MultiplyImageOperation{Color::fromHex(bits.at(1)).toRgba()}; + + } else if (type == "border" || type == "outline") { + BorderImageOperation operation; + operation.pixels = lexicalCast<unsigned>(bits.at(1)); + operation.startColor = Color::fromHex(bits.at(2)).toRgba(); + if (bits.size() > 3) + operation.endColor = Color::fromHex(bits.at(3)).toRgba(); + else + operation.endColor = operation.startColor; + operation.outlineOnly = type == "outline"; + + return operation; + + } else if (type == "scalenearest" || type == "scalebilinear" || type == "scalebicubic" || type == "scale") { + Vec2F scale; + if (bits.size() == 2) + scale = Vec2F::filled(lexicalCast<float>(bits.at(1))); + else + scale = Vec2F(lexicalCast<float>(bits.at(1)), lexicalCast<float>(bits.at(2))); + + ScaleImageOperation::Mode mode; + if (type == "scalenearest") + mode = ScaleImageOperation::Nearest; + else if (type == "scalebicubic") + mode = ScaleImageOperation::Bicubic; + else + mode = ScaleImageOperation::Bilinear; + + return ScaleImageOperation{mode, scale}; + + } else if (type == "crop") { + return CropImageOperation{RectI(lexicalCast<float>(bits.at(1)), lexicalCast<float>(bits.at(2)), + lexicalCast<float>(bits.at(3)), lexicalCast<float>(bits.at(4)))}; + + } else if (type == "flipx") { + return FlipImageOperation{FlipImageOperation::FlipX}; + + } else if (type == "flipy") { + return FlipImageOperation{FlipImageOperation::FlipY}; + + } else if (type == "flipxy") { + return FlipImageOperation{FlipImageOperation::FlipXY}; + + } else { + throw ImageOperationException(strf("Could not recognize ImageOperation type %s", type)); + } + } catch (OutOfRangeException const& e) { + throw ImageOperationException("Error reading ImageOperation", e); + } catch (BadLexicalCast const& e) { + throw ImageOperationException("Error reading ImageOperation", e); + } +} + +String imageOperationToString(ImageOperation const& operation) { + if (auto op = operation.ptr<HueShiftImageOperation>()) { + return strf("hueshift=%s", op->hueShiftAmount * 360.0f); + } else if (auto op = operation.ptr<SaturationShiftImageOperation>()) { + return strf("saturation=%s", op->saturationShiftAmount * 100.0f); + } else if (auto op = operation.ptr<BrightnessMultiplyImageOperation>()) { + return strf("brightness=%s", (op->brightnessMultiply - 1.0f) * 100.0f); + } else if (auto op = operation.ptr<FadeToColorImageOperation>()) { + return strf("fade=%s=%s", Color::rgb(op->color).toHex(), op->amount); + } else if (auto op = operation.ptr<ScanLinesImageOperation>()) { + return strf("scanlines=%s=%s=%s=%s", + Color::rgb(op->fade1.color).toHex(), + op->fade1.amount, + Color::rgb(op->fade2.color).toHex(), + op->fade2.amount); + } else if (auto op = operation.ptr<SetColorImageOperation>()) { + return strf("setcolor=%s", Color::rgb(op->color).toHex()); + } else if (auto op = operation.ptr<ColorReplaceImageOperation>()) { + String str = "replace"; + for (auto const& pair : op->colorReplaceMap) + str += strf(";%s=%s", Color::rgba(pair.first).toHex(), Color::rgba(pair.second).toHex()); + return str; + } else if (auto op = operation.ptr<AlphaMaskImageOperation>()) { + if (op->mode == AlphaMaskImageOperation::Additive) + return strf("addmask=%s;%s;%s", op->maskImages.join("+"), op->offset[0], op->offset[1]); + else if (op->mode == AlphaMaskImageOperation::Subtractive) + return strf("submask=%s;%s;%s", op->maskImages.join("+"), op->offset[0], op->offset[1]); + } else if (auto op = operation.ptr<BlendImageOperation>()) { + if (op->mode == BlendImageOperation::Multiply) + return strf("blendmult=%s;%s;%s", op->blendImages.join("+"), op->offset[0], op->offset[1]); + else if (op->mode == BlendImageOperation::Screen) + return strf("blendscreen=%s;%s;%s", op->blendImages.join("+"), op->offset[0], op->offset[1]); + } else if (auto op = operation.ptr<MultiplyImageOperation>()) { + return strf("multiply=%s", Color::rgba(op->color).toHex()); + } else if (auto op = operation.ptr<BorderImageOperation>()) { + if (op->outlineOnly) + return strf("outline=%d;%s;%s", op->pixels, Color::rgba(op->startColor).toHex(), Color::rgba(op->endColor).toHex()); + else + return strf("border=%d;%s;%s", op->pixels, Color::rgba(op->startColor).toHex(), Color::rgba(op->endColor).toHex()); + } else if (auto op = operation.ptr<ScaleImageOperation>()) { + if (op->mode == ScaleImageOperation::Nearest) + return strf("scalenearest=%s", op->scale); + else if (op->mode == ScaleImageOperation::Bilinear) + return strf("scalebilinear=%s", op->scale); + else if (op->mode == ScaleImageOperation::Bicubic) + return strf("scalebicubic=%s", op->scale); + } else if (auto op = operation.ptr<CropImageOperation>()) { + return strf("crop=%s;%s;%s;%s", op->subset.xMin(), op->subset.xMax(), op->subset.yMin(), op->subset.yMax()); + } else if (auto op = operation.ptr<FlipImageOperation>()) { + if (op->mode == FlipImageOperation::FlipX) + return "flipx"; + else if (op->mode == FlipImageOperation::FlipY) + return "flipy"; + else if (op->mode == FlipImageOperation::FlipXY) + return "flipxy"; + } + + return ""; +} + +List<ImageOperation> parseImageOperations(String const& params) { + List<ImageOperation> operations; + for (auto const& op : params.split('?')) { + if (!op.empty()) + operations.append(imageOperationFromString(op)); + } + return operations; +} + +String printImageOperations(List<ImageOperation> const& list) { + return StringList(list.transformed(imageOperationToString)).join("?"); +} + +StringList imageOperationReferences(List<ImageOperation> const& operations) { + StringList references; + for (auto const& operation : operations) { + if (auto op = operation.ptr<AlphaMaskImageOperation>()) + references.appendAll(op->maskImages); + else if (auto op = operation.ptr<BlendImageOperation>()) + references.appendAll(op->blendImages); + } + return references; +} + +Image processImageOperations(List<ImageOperation> const& operations, Image image, ImageReferenceCallback refCallback) { + for (auto const& operation : operations) { + if (auto op = operation.ptr<HueShiftImageOperation>()) { + image.forEachPixel([&op](unsigned, unsigned, Vec4B& pixel) { + if (pixel[3] != 0) + pixel = Color::hueShiftVec4B(pixel, op->hueShiftAmount); + }); + } else if (auto op = operation.ptr<SaturationShiftImageOperation>()) { + image.forEachPixel([&op](unsigned, unsigned, Vec4B& pixel) { + if (pixel[3] != 0) { + Color color = Color::rgba(pixel); + color.setSaturation(clamp(color.saturation() + op->saturationShiftAmount, 0.0f, 1.0f)); + pixel = color.toRgba(); + } + }); + } else if (auto op = operation.ptr<BrightnessMultiplyImageOperation>()) { + image.forEachPixel([&op](unsigned, unsigned, Vec4B& pixel) { + if (pixel[3] != 0) { + Color color = Color::rgba(pixel); + color.setValue(clamp(color.value() * op->brightnessMultiply, 0.0f, 1.0f)); + pixel = color.toRgba(); + } + }); + } else if (auto op = operation.ptr<FadeToColorImageOperation>()) { + image.forEachPixel([&op](unsigned, unsigned, Vec4B& pixel) { + pixel[0] = op->rTable[pixel[0]]; + pixel[1] = op->gTable[pixel[1]]; + pixel[2] = op->bTable[pixel[2]]; + }); + } else if (auto op = operation.ptr<ScanLinesImageOperation>()) { + image.forEachPixel([&op](unsigned, unsigned y, Vec4B& pixel) { + if (y % 2 == 0) { + pixel[0] = op->fade1.rTable[pixel[0]]; + pixel[1] = op->fade1.gTable[pixel[1]]; + pixel[2] = op->fade1.bTable[pixel[2]]; + } else { + pixel[0] = op->fade2.rTable[pixel[0]]; + pixel[1] = op->fade2.gTable[pixel[1]]; + pixel[2] = op->fade2.bTable[pixel[2]]; + } + }); + } else if (auto op = operation.ptr<SetColorImageOperation>()) { + image.forEachPixel([&op](unsigned, unsigned, Vec4B& pixel) { + pixel[0] = op->color[0]; + pixel[1] = op->color[1]; + pixel[2] = op->color[2]; + }); + } else if (auto op = operation.ptr<ColorReplaceImageOperation>()) { + image.forEachPixel([&op](unsigned, unsigned, Vec4B& pixel) { + if (auto m = op->colorReplaceMap.maybe(pixel)) + pixel = *m; + }); + + } else if (auto op = operation.ptr<AlphaMaskImageOperation>()) { + if (op->maskImages.empty()) + continue; + + if (!refCallback) + throw StarException("Missing image ref callback during AlphaMaskImageOperation in ImageProcessor::process"); + + List<Image const*> maskImages; + for (auto const& reference : op->maskImages) + maskImages.append(refCallback(reference)); + + image.forEachPixel([&op, &maskImages](unsigned x, unsigned y, Vec4B& pixel) { + uint8_t maskAlpha = 0; + Vec2U pos = Vec2U(Vec2I(x, y) + op->offset); + for (auto mask : maskImages) { + if (pos[0] < mask->width() && pos[1] < mask->height()) { + if (op->mode == AlphaMaskImageOperation::Additive) { + // We produce our mask alpha from the maximum alpha of any of + // the + // mask images. + maskAlpha = std::max(maskAlpha, mask->get(pos)[3]); + } else if (op->mode == AlphaMaskImageOperation::Subtractive) { + // We produce our mask alpha from the minimum alpha of any of + // the + // mask images. + maskAlpha = std::min(maskAlpha, mask->get(pos)[3]); + } + } + } + pixel[3] = std::min(pixel[3], maskAlpha); + }); + + } else if (auto op = operation.ptr<BlendImageOperation>()) { + if (op->blendImages.empty()) + continue; + + if (!refCallback) + throw StarException("Missing image ref callback during BlendImageOperation in ImageProcessor::process"); + + List<Image const*> blendImages; + for (auto const& reference : op->blendImages) + blendImages.append(refCallback(reference)); + + image.forEachPixel([&op, &blendImages](unsigned x, unsigned y, Vec4B& pixel) { + Vec2U pos = Vec2U(Vec2I(x, y) + op->offset); + Vec4F fpixel = Color::v4bToFloat(pixel); + for (auto blend : blendImages) { + if (pos[0] < blend->width() && pos[1] < blend->height()) { + Vec4F blendPixel = Color::v4bToFloat(blend->get(pos)); + if (op->mode == BlendImageOperation::Multiply) + fpixel = fpixel.piecewiseMultiply(blendPixel); + else if (op->mode == BlendImageOperation::Screen) + fpixel = Vec4F::filled(1.0f) - (Vec4F::filled(1.0f) - fpixel).piecewiseMultiply(Vec4F::filled(1.0f) - blendPixel); + } + } + pixel = Color::v4fToByte(fpixel); + }); + + } else if (auto op = operation.ptr<MultiplyImageOperation>()) { + image.forEachPixel([&op](unsigned, unsigned, Vec4B& pixel) { + pixel = pixel.combine(op->color, [](uint8_t a, uint8_t b) -> uint8_t { + return (uint8_t)(((int)a * (int)b) / 255); + }); + }); + + } else if (auto op = operation.ptr<BorderImageOperation>()) { + Image borderImage(image.size() + Vec2U::filled(op->pixels * 2), PixelFormat::RGBA32); + borderImage.copyInto(Vec2U::filled(op->pixels), image); + Vec2I borderImageSize = Vec2I(borderImage.size()); + + borderImage.forEachPixel([&op, &image, &borderImageSize](int x, int y, Vec4B& pixel) { + int pixels = op->pixels; + if (pixel[3] == 0) { + int dist = std::numeric_limits<int>::max(); + for (int j = -pixels; j < pixels + 1; j++) { + for (int i = -pixels; i < pixels + 1; i++) { + if (i + x >= pixels && j + y >= pixels && i + x < borderImageSize[0] - pixels && j + y < borderImageSize[1] - pixels) { + Vec4B remotePixel = image.get(i + x - pixels, j + y - pixels); + if (remotePixel[3] != 0) { + dist = std::min(dist, abs(i) + abs(j)); + if (dist == 1) // Early out, if dist is 1 it ain't getting shorter + break; + } + } + } + } + + if (dist < std::numeric_limits<int>::max()) { + float percent = (dist - 1) / (2.0f * pixels - 1); + pixel = Vec4B(Vec4F(op->startColor) * (1 - percent) + Vec4F(op->endColor) * percent); + } + } else if (op->outlineOnly) { + pixel = Vec4B(0, 0, 0, 0); + } + }); + + image = borderImage; + + } else if (auto op = operation.ptr<ScaleImageOperation>()) { + if (op->mode == ScaleImageOperation::Nearest) + image = scaleNearest(image, op->scale); + else if (op->mode == ScaleImageOperation::Bilinear) + image = scaleBilinear(image, op->scale); + else if (op->mode == ScaleImageOperation::Bicubic) + image = scaleBicubic(image, op->scale); + + } else if (auto op = operation.ptr<CropImageOperation>()) { + image = image.subImage(Vec2U(op->subset.min()), Vec2U(op->subset.size())); + + } else if (auto op = operation.ptr<FlipImageOperation>()) { + if (op->mode == FlipImageOperation::FlipX || op->mode == FlipImageOperation::FlipXY) { + for (size_t y = 0; y < image.height(); ++y) { + for (size_t xLeft = 0; xLeft < image.width() / 2; ++xLeft) { + size_t xRight = image.width() - 1 - xLeft; + + auto left = image.get(xLeft, y); + auto right = image.get(xRight, y); + + image.set(xLeft, y, right); + image.set(xRight, y, left); + } + } + } + + if (op->mode == FlipImageOperation::FlipY || op->mode == FlipImageOperation::FlipXY) { + for (size_t x = 0; x < image.width(); ++x) { + for (size_t yTop = 0; yTop < image.height() / 2; ++yTop) { + size_t yBottom = image.height() - 1 - yTop; + + auto top = image.get(x, yTop); + auto bottom = image.get(x, yBottom); + + image.set(x, yTop, bottom); + image.set(x, yBottom, top); + } + } + } + } + } + + return image; +} + +} diff --git a/source/core/StarImageProcessing.hpp b/source/core/StarImageProcessing.hpp new file mode 100644 index 0000000..e2008b8 --- /dev/null +++ b/source/core/StarImageProcessing.hpp @@ -0,0 +1,153 @@ +#ifndef STAR_IMAGE_PROCESSING_HPP +#define STAR_IMAGE_PROCESSING_HPP + +#include "StarList.hpp" +#include "StarRect.hpp" +#include "StarJson.hpp" + +namespace Star { + +STAR_CLASS(Image); + +STAR_EXCEPTION(ImageOperationException, StarException); + +Image scaleNearest(Image const& srcImage, Vec2F const& scale); +Image scaleBilinear(Image const& srcImage, Vec2F const& scale); +Image scaleBicubic(Image const& srcImage, Vec2F const& scale); + +StringList colorDirectivesFromConfig(JsonArray const& directives); +String paletteSwapDirectivesFromConfig(Json const& swaps); + +struct HueShiftImageOperation { + // Specify hue shift angle as -360 to 360 rather than -1 to 1 + static HueShiftImageOperation hueShiftDegrees(float degrees); + + // value here is normalized to 1.0 + float hueShiftAmount; +}; + +struct SaturationShiftImageOperation { + // Specify saturation shift as amount normalized to 100 + static SaturationShiftImageOperation saturationShift100(float amount); + + // value here is normalized to 1.0 + float saturationShiftAmount; +}; + +struct BrightnessMultiplyImageOperation { + // Specify brightness multiply as amount where 0 means "no change" and 100 + // means "x2" and -100 means "x0" + static BrightnessMultiplyImageOperation brightnessMultiply100(float amount); + + float brightnessMultiply; +}; + +// Fades R G and B channels to the given color by the given amount, ignores A +struct FadeToColorImageOperation { + FadeToColorImageOperation(Vec3B color, float amount); + + Vec3B color; + float amount; + + Array<uint8_t, 256> rTable; + Array<uint8_t, 256> gTable; + Array<uint8_t, 256> bTable; +}; + +// Applies two FadeToColor operations in alternating rows to produce a scanline effect +struct ScanLinesImageOperation { + FadeToColorImageOperation fade1; + FadeToColorImageOperation fade2; +}; + +// Sets RGB values to the given color, and ignores the alpha channel +struct SetColorImageOperation { + Vec3B color; +}; + +typedef HashMap<Vec4B, Vec4B> ColorReplaceMap; + +struct ColorReplaceImageOperation { + ColorReplaceMap colorReplaceMap; +}; + +struct AlphaMaskImageOperation { + enum MaskMode { + Additive, + Subtractive + }; + + MaskMode mode; + StringList maskImages; + Vec2I offset; +}; + +struct BlendImageOperation { + enum BlendMode { + Multiply, + Screen + }; + + BlendMode mode; + StringList blendImages; + Vec2I offset; +}; + +struct MultiplyImageOperation { + Vec4B color; +}; + +struct BorderImageOperation { + unsigned pixels; + Vec4B startColor; + Vec4B endColor; + bool outlineOnly; +}; + +struct ScaleImageOperation { + enum Mode { + Nearest, + Bilinear, + Bicubic + }; + + Mode mode; + Vec2F scale; +}; + +struct CropImageOperation { + RectI subset; +}; + +struct FlipImageOperation { + enum Mode { + FlipX, + FlipY, + FlipXY + }; + Mode mode; +}; + +typedef Variant<HueShiftImageOperation, SaturationShiftImageOperation, BrightnessMultiplyImageOperation, FadeToColorImageOperation, + ScanLinesImageOperation, SetColorImageOperation, ColorReplaceImageOperation, AlphaMaskImageOperation, BlendImageOperation, + MultiplyImageOperation, BorderImageOperation, ScaleImageOperation, CropImageOperation, FlipImageOperation> ImageOperation; + +ImageOperation imageOperationFromString(String const& string); +String imageOperationToString(ImageOperation const& operation); + +// Each operation is assumed to be separated by '?', with parameters +// separated by ';' or '=' +List<ImageOperation> parseImageOperations(String const& params); + +// Each operation separated by '?', returns string with leading '?' +String printImageOperations(List<ImageOperation> const& operations); + +StringList imageOperationReferences(List<ImageOperation> const& operations); + +typedef function<Image const*(String const& refName)> ImageReferenceCallback; + +Image processImageOperations(List<ImageOperation> const& operations, Image input, ImageReferenceCallback refCallback = {}); + +} + +#endif diff --git a/source/core/StarInterpolation.hpp b/source/core/StarInterpolation.hpp new file mode 100644 index 0000000..1797495 --- /dev/null +++ b/source/core/StarInterpolation.hpp @@ -0,0 +1,454 @@ +#ifndef STAR_INTERPOLATE_BASE +#define STAR_INTERPOLATE_BASE + +#include "StarMathCommon.hpp" +#include "StarArray.hpp" +#include "StarAlgorithm.hpp" + +namespace Star { + +enum class BoundMode { + Clamp, + Extrapolate, + Wrap +}; + +enum class InterpolationMode { + HalfStep, + Linear, + Cubic +}; + +template <typename T1, typename T2> +T2 angleLerp(T1 const& offset, T2 const& f0, T2 const& f1) { + return f0 + angleDiff(f0, f1) * offset; +} + +template <typename T1, typename T2> +T2 sinEase(T1 const& offset, T2 const& f0, T2 const& f1) { + T1 w = (sin(offset * Constants::pi - Constants::pi / 2) + 1) / 2; + return f0 * (1 - w) + f1 * w; +} + +template <typename T1, typename T2> +T2 lerp(T1 const& offset, T2 const& f0, T2 const& f1) { + return f0 * (1 - offset) + f1 * (offset); +} + +template <typename T1, typename T2> +T2 lerpWithLimit(Maybe<T2> const& limit, T1 const& offset, T2 const& f0, T2 const& f1) { + if (limit && abs(f1 - f0) > *limit) + return f1; + return lerp(offset, f0, f1); +} + +template <typename T1, typename T2> +T2 step(T1 threshold, T1 x, T2 a, T2 b) { + if (x < threshold) + return a; + else + return b; +} + +template <typename T1, typename T2> +T2 halfStep(T1 x, T2 a, T2 b) { + if (x < 0.5) + return a; + else + return b; +} + +template <typename T1, typename T2> +T2 cubic4(T1 const& x, T2 const& f0, T2 const& f1, T2 const& f2, T2 const& f3) { + // (-1/2 * f0 + 3/2 * f1 + -3/2 * f2 + 1/2 * f3) * x * x * x + + // ( 1 * f0 + -5/2 * f1 + 2 * f2 + -1/2 * f3) * x * x + + // (-1/2 * f0 + 0 * f1 + 1/2 * f2 + 0 * f3) * x + + // ( 0 * f0 + 1 * f1 + 0 * f2 + 0 * f3) * 1.0 + return f1 + (f2 - f0 + (f0 * 2.0 - f1 * 5.0 + f2 * 4.0 - f3 + ((f1 - f2) * 3.0 + f3 - f0) * x) * x) * x * 0.5; +} + +template <typename T1, typename T2> +T2 catmulRom4(T1 const& x, T2 const& f0, T2 const& f1, T2 const& f2, T2 const& f3) { + return ((f1 * 2) + (-f0 + f2) * x + (f0 * 2 - f1 * 5 + f2 * 4 - f3) * x * x + + (-f0 + f1 * 3 - f2 * 3 + f3) * x * x * x) + * 0.5; +} + +template <typename T1, typename T2> +T2 hermite2(T1 const& x, T2 const& a, T2 const& b) { + return a + (b - a) * x * x * (3 - 2 * x); +} + +template <typename T1, typename T2> +T2 quintic2(T1 const& x, T2 const& a, T2 const& b) { + return a + (b - a) * x * x * x * (x * (x * 6 - 15) + 10); +} + +template <typename WeightT> +struct LinearWeightOperator { + typedef WeightT Weight; + typedef Array<Weight, 2> WeightVec; + + WeightVec operator()(Weight x) const { + return {1 - x, x}; + } +}; + +template <typename WeightT> +struct StepWeightOperator { + typedef WeightT Weight; + typedef Array<Weight, 2> WeightVec; + + StepWeightOperator(Weight threshold = 0.5) : threshold(threshold) {} + + WeightVec operator()(Weight x) const { + if (x < threshold) + return {1, 0}; + else + return {0, 1}; + } + + Weight threshold; +}; + +template <typename WeightT> +struct SinWeightOperator { + typedef WeightT Weight; + typedef Array<Weight, 2> WeightVec; + + WeightVec operator()(Weight x) const { + Weight w = (sin(x * Constants::pi - Constants::pi / 2) + 1) / 2; + return {1 - w, w}; + } +}; + +template <typename WeightT> +struct Hermite2WeightOperator { + typedef WeightT Weight; + typedef Array<Weight, 2> WeightVec; + + WeightVec operator()(Weight x) const { + Weight w = x * x * (3 - 2 * x); + return {1 - w, w}; + } +}; + +template <typename WeightT> +struct Quintic2WeightOperator { + typedef WeightT Weight; + typedef Array<Weight, 2> WeightVec; + + WeightVec operator()(Weight x) const { + Weight w = x * x * x * (x * (x * 6 - 15) + 10); + return {1 - w, w}; + } +}; + +// Setting 'LinearExtrapolate' flag to true changes the weights to be linear +// when x is outside of the range [0.0, 1.0] +template <typename WeightT> +struct Cubic4WeightOperator { + typedef WeightT Weight; + typedef Array<Weight, 4> WeightVec; + + Cubic4WeightOperator(bool le = false) : linearExtrapolate(le) {} + + WeightVec operator()(Weight x) const { + if (linearExtrapolate && x > 1) { + return {0, 0, 2 - x, x - 1}; + } else if (linearExtrapolate && x < 0) { + return {-x, 1 + x, 0, 0}; + } else { + // (-1/2 * f0 + 3/2 * f1 + -3/2 * f2 + 1/2 * f3) * x*x*x + + // ( 1 * f0 + -5/2 * f1 + 2 * f2 + -1/2 * f3) * x*x + + // (-1/2 * f0 + 0 * f1 + 1/2 * f2 + 0 * f3) * x + + // ( 0 * f0 + 1 * f1 + 0 * f2 + 0 * f3) * 1.0 + + Weight x2 = x * x; + Weight x3 = x2 * x; + return WeightVec(-0.5 * x3 + 1 * x2 - 0.5 * x, + 1.5 * x3 + -2.5 * x2 + 1.0, + -1.5 * x3 + 2.0 * x2 + 0.5 * x, + 0.5 * x3 - 0.5 * x2); + } + } + bool linearExtrapolate; +}; + +// Setting 'LinearExtrapolate' flag to true changes the weights to be linear +// when x is outside of the range [0.0, 1.0] +template <typename WeightT> +struct Catmul4WeightOperator { + typedef WeightT Weight; + typedef Array<Weight, 4> WeightVec; + + Catmul4WeightOperator(bool le = false) : linearExtrapolate(le) {} + + WeightVec operator()(Weight x) const { + if (linearExtrapolate && x > 1) { + return {0, 0, 2 - x, x - 1}; + } else if (linearExtrapolate && x < 0) { + return {-x, 1 + x, 0, 0}; + } else { + Weight x2 = x * x; + Weight x3 = x * x * x; + return {(-x3 + x2 * 2 - x) / 2, (x3 * 3 - x2 * 5 + 2) / 2, (-x3 * 3 + x2 * 4 + x) / 2, (x3 - x2) / 2}; + } + } + + bool linearExtrapolate; +}; + +template <typename Loctype, typename IndexType> +struct Bound2 { + IndexType i0; + IndexType i1; + Loctype offset; +}; + +// loc should be in "index space", meaning that 0 points exactly to the first +// element and extent - 1 +// points exactly to the last element. +template <typename LocType, typename IndexType> +Bound2<LocType, IndexType> getBound2(LocType loc, IndexType extent, BoundMode bmode) { + Bound2<LocType, IndexType> bound; + if (extent <= 1) { + bound.i0 = bound.i1 = bound.offset = 0; + return bound; + } + + bound.offset = 0; + if (bmode == BoundMode::Wrap) { + loc = pfmod<LocType>(loc, extent); + } else { + LocType newLoc = clamp<LocType>(loc, 0, extent - 1); + if (bmode == BoundMode::Extrapolate) + bound.offset += loc - newLoc; + + loc = newLoc; + } + + bound.i0 = IndexType(loc); + + if (bound.i0 == extent - 1) { + if (bmode == BoundMode::Wrap) { + bound.i1 = 0; + } else { + bound.i1 = bound.i0; + bound.i0 -= 1; + } + } else { + bound.i1 = bound.i0 + 1; + } + + bound.offset += loc - bound.i0; + + return bound; +} + +template <typename Loctype, typename IndexType> +struct Bound4 { + Bound4() {} + IndexType i0; + IndexType i1; + IndexType i2; + IndexType i3; + Loctype offset; +}; + +// loc should be in "index space", meaning that 0 points exactly to the first +// element and extent - 1 +// points exactly to the last element. +template <typename LocType, typename IndexType> +Bound4<LocType, IndexType> getBound4(LocType loc, IndexType extent, BoundMode bmode) { + Bound4<LocType, IndexType> bound; + if (extent <= 1) { + bound.i0 = bound.i1 = bound.i2 = bound.i3 = bound.offset = 0; + return bound; + } + + bound.offset = 0; + if (bmode == BoundMode::Wrap) { + loc = pfmod<LocType>(loc, extent); + } else { + LocType newLoc = clamp<LocType>(loc, 0, extent - 1); + if (bmode == BoundMode::Extrapolate) + bound.offset += loc - newLoc; + + loc = newLoc; + } + + bound.i1 = IndexType(loc); + + if (bound.i1 == extent - 1) { + if (bmode == BoundMode::Wrap) { + bound.i0 = bound.i1 - 1; + bound.i2 = 0; + bound.i3 = 1; + } else { + bound.i1 = bound.i1 - 2; + + bound.i0 = bound.i1 - 1; + bound.i2 = bound.i1 + 1; + bound.i3 = bound.i2 + 1; + } + } else if (bound.i1 == extent - 2) { + if (bmode == BoundMode::Wrap) { + bound.i0 = bound.i1 - 1; + bound.i2 = bound.i1 + 1; + bound.i3 = 0; + } else { + bound.i1 = bound.i1 - 1; + + bound.i0 = bound.i1 - 1; + bound.i2 = bound.i1 + 1; + bound.i3 = bound.i2 + 1; + } + } else if (bound.i1 == 0) { + if (bmode == BoundMode::Wrap) { + bound.i0 = extent - 1; + bound.i2 = bound.i1 + 1; + bound.i3 = bound.i2 + 1; + } else { + bound.i1 = bound.i1 + 1; + + bound.i0 = bound.i1 - 1; + bound.i2 = bound.i1 + 1; + bound.i3 = bound.i2 + 1; + } + } else { + bound.i0 = bound.i1 - 1; + bound.i2 = bound.i1 + 1; + bound.i3 = bound.i1 + 2; + } + + bound.offset += loc - bound.i1; + + return bound; +} + +template <typename Container, typename Pos, typename WeightOp> +typename Container::value_type listInterpolate2( + Container const& cont, Pos x, WeightOp weightOp, BoundMode bmode = BoundMode::Clamp) { + if (cont.size() == 0) { + return typename Container::value_type(); + } else if (cont.size() == 1) { + return cont[0]; + } else { + auto bound = getBound2(x, cont.size(), bmode); + auto weights = weightOp(bound.offset); + return cont[bound.i0] * weights[0] + cont[bound.i1] * weights[1]; + } +} + +template <typename Container, typename Pos, typename WeightOp> +typename Container::value_type listInterpolate4( + Container const& cont, Pos x, WeightOp weightOp, BoundMode bmode = BoundMode::Clamp) { + if (cont.size() == 0) { + return typename Container::value_type(); + } else if (cont.size() == 1) { + return cont[0]; + } else { + auto bound = getBound4(x, cont.size(), bmode); + auto weights = weightOp(bound.offset); + return cont[bound.i0] * weights[0] + cont[bound.i1] * weights[1] + cont[bound.i2] * weights[2] + + cont[bound.i3] * weights[3]; + } +} + +// Returns an index value (not integer) that represents the value that, if +// passed in as an index to a simple linear interpolation of the given +// container, would yield the given value. (In other words, this goes from +// function space to index space on a list of points). Useful for doing +// interpolation on functions that are unevenly spaced. Given container must +// be sorted. If there is an ambiguity on points due to repeat points, will +// choose the lower-most of the points. +template <typename Iterator, typename Pos, typename Comp, typename PosGetter> +Pos inverseLinearInterpolateLower(Iterator begin, Iterator end, Pos t, Comp&& comp, PosGetter&& posGetter) { + // Container must be at least size 2 for this to make sense. + if (begin == end || std::next(begin) == end) + return Pos(); + + Iterator i = std::lower_bound(std::next(begin), std::prev(end), t, forward<Comp>(comp)); + + --i; + Pos min = posGetter(*i); + Pos max = posGetter(*(++i)); + Pos ipos = Pos(std::distance(begin, --i)); + + Pos dist = max - min; + if (dist == 0) + return ipos; + else + return ipos + (t - min) / dist; +} + +template <typename Iterator, typename Pos> +Pos inverseLinearInterpolateLower(Iterator begin, Iterator end, Pos t) { + return inverseLinearInterpolateLower(begin, end, t, std::less<Pos>(), identity()); +} + +// Same as inverseLinearInterpolateLower, except chooses the upper most of the +// points in the ambiguous case. +template <typename Iterator, typename Pos, typename Comp, typename PosGetter> +Pos inverseLinearInterpolateUpper(Iterator begin, Iterator end, Pos t, Comp&& comp, PosGetter&& posGetter) { + // Container must be at least size 2 for this to make sense. + if (begin == end || std::next(begin) == end) + return Pos(); + + Iterator i = std::upper_bound(std::next(begin), std::prev(end), t, forward<Comp>(comp)); + + --i; + Pos min = posGetter(*i); + Pos max = posGetter(*(++i)); + Pos ipos = Pos(std::distance(begin, --i)); + + Pos dist = max - min; + if (dist == 0) + return ipos + 1; + else + return ipos + (t - min) / dist; +} + +template <typename Iterator, typename Pos> +Pos inverseLinearInterpolateUpper(Iterator begin, Iterator end, Pos t) { + return inverseLinearInterpolateUpper(begin, end, t, std::less<Pos>(), identity()); +} + +template <typename XContainer, typename YContainer, typename PositionType, typename WeightOp> +typename YContainer::value_type parametricInterpolate2(XContainer const& xvals, + YContainer const& yvals, + PositionType const& position, + WeightOp weightOp, + BoundMode bmode) { + starAssert(xvals.size() != 0); + starAssert(xvals.size() == yvals.size()); + + if (yvals.size() == 1) + return yvals[0]; + + PositionType ipos = inverseLinearInterpolateLower(xvals.begin(), xvals.end(), position); + + return listInterpolate2(yvals, ipos, weightOp, bmode); +} + +template <typename XContainer, typename YContainer, typename PositionType, typename WeightOp> +typename YContainer::value_type parametricInterpolate4(XContainer const& xvals, + YContainer const& yvals, + PositionType const& position, + WeightOp weightOp, + BoundMode bmode) { + starAssert(xvals.size() != 0); + starAssert(xvals.size() == yvals.size()); + + if (yvals.size() == 1) + return yvals[0]; + + PositionType ipos = inverseLinearInterpolateLower(xvals.begin(), xvals.end(), position); + + return listInterpolate4(yvals, ipos, weightOp, bmode); +} + +} + +#endif diff --git a/source/core/StarIterator.hpp b/source/core/StarIterator.hpp new file mode 100644 index 0000000..7770d84 --- /dev/null +++ b/source/core/StarIterator.hpp @@ -0,0 +1,437 @@ +#ifndef STAR_ITERATOR_H +#define STAR_ITERATOR_H + +#include <algorithm> + +#include "StarException.hpp" + +namespace Star { + +STAR_EXCEPTION(IteratorException, StarException); + +// Provides java style iterators for bidirectional list-like containers +// (SIterator and SMutableIterator) and forward only map-like containers +// (SMapIterator and SMutableMapIterator) +template <typename Container> +class SIterator { +public: + typedef typename Container::const_iterator iterator; + typedef decltype(*iterator()) value_ref; + + SIterator(Container const& c) : cont(c) { + toFront(); + } + + void toFront() { + curr = cont.begin(); + direction = 0; + } + + void toBack() { + curr = cont.end(); + direction = 0; + } + + bool hasNext() const { + return curr != cont.end(); + } + + bool hasPrevious() const { + return curr != cont.begin(); + } + + value_ref value() const { + if (direction == 1) { + if (curr != cont.end() && cont.size() != 0) + return *curr; + else + throw IteratorException("value() called on end()"); + } else if (direction == -1) { + if (curr != cont.begin() && cont.size() != 0) { + iterator back = curr; + return *(--back); + } else { + throw IteratorException("value() called on begin()"); + } + } else { + throw IteratorException("value() called without previous next() or previous()"); + } + } + + value_ref next() { + if (hasNext()) { + direction = -1; + return *(curr++); + } + throw IteratorException("next() called on end"); + } + + value_ref previous() { + if (hasPrevious()) { + direction = 1; + return *(--curr); + } + throw IteratorException("prev() called on beginning"); + } + + value_ref peekNext() const { + SIterator t = *this; + return t.next(); + } + + value_ref peekPrevious() const { + SIterator t = *this; + return t.previous(); + } + + size_t distFront() const { + return std::distance(cont.begin(), curr); + } + + size_t distBack() const { + return std::distance(curr, cont.end()); + } + +private: + SIterator& operator=(iterator const& i) { + return iterator::operator=(i); + } + Container const& cont; + iterator curr; + + int direction; +}; + +template <typename Container> +SIterator<Container> makeSIterator(Container const& c) { + return SIterator<Container>(c); +} + +template <typename Container> +class SMutableIterator { +public: + typedef typename Container::value_type value_type; + typedef typename Container::iterator iterator; + typedef decltype(*iterator()) value_ref; + + SMutableIterator(Container& c) : cont(c) { + toFront(); + } + + void toFront() { + curr = cont.begin(); + direction = 0; + } + + void toBack() { + curr = cont.end(); + direction = 0; + } + + bool hasNext() const { + return curr != cont.end(); + } + + bool hasPrevious() const { + return curr != cont.begin(); + } + + void insert(value_type v) { + curr = ++cont.insert(curr, move(v)); + direction = -1; + } + + void remove() { + if (direction == 1) { + direction = 0; + if (curr != cont.end() && cont.size() != 0) + curr = cont.erase(curr); + else + throw IteratorException("remove() called on end()"); + } else if (direction == -1) { + direction = 0; + if (curr != cont.begin() && cont.size() != 0) + curr = cont.erase(--curr); + else + throw IteratorException("remove() called on begin()"); + } else { + throw IteratorException("remove() called without previous next() or previous()"); + } + } + + value_ref value() const { + if (direction == 1) { + if (curr != cont.end() && cont.size() != 0) + return *curr; + else + throw IteratorException("value() called on end()"); + } else if (direction == -1) { + if (curr != cont.begin() && cont.size() != 0) { + iterator back = curr; + return *(--back); + } else { + throw IteratorException("value() called on begin()"); + } + } else { + throw IteratorException("value() called without previous next() or previous()"); + } + } + + void setValue(value_type v) const { + value() = move(v); + } + + value_ref next() { + if (curr == cont.end()) + throw IteratorException("next() called on end"); + direction = -1; + return *curr++; + } + + value_ref previous() { + if (curr == cont.begin()) + throw IteratorException("previous() called on begin"); + direction = 1; + return *--curr; + } + + value_ref peekNext() const { + SMutableIterator n = *this; + return n.next(); + } + + value_ref peekPrevious() const { + SMutableIterator n = *this; + return n.previous(); + } + + size_t distFront() const { + return std::distance(cont.begin(), curr); + } + + size_t distBack() const { + return std::distance(curr, cont.end()); + } + +private: + SMutableIterator& operator=(iterator const& i) { + return iterator::operator=(i); + } + + Container& cont; + iterator curr; + + // -1 means remove() will remove --cur, +1 means ++cur, 0 means remove() is + // invalid. + int direction; +}; + +template <typename Container> +SMutableIterator<Container> makeSMutableIterator(Container& c) { + return SMutableIterator<Container>(c); +} + +template <typename Container> +class SMapIterator { +public: + typedef typename Container::key_type key_type; + typedef typename Container::mapped_type mapped_type; + + typedef typename Container::const_iterator iterator; + typedef decltype(*iterator()) value_ref; + + SMapIterator(Container const& c) : cont(c) { + toFront(); + } + + void toFront() { + curr = cont.end(); + } + + void toBack() { + curr = cont.end(); + if (curr != cont.begin()) + --curr; + } + + bool hasNext() const { + iterator end = cont.end(); + if (curr == end) + return cont.begin() != end; + else + return ++iterator(curr) != end; + } + + key_type const& key() const { + if (curr != cont.end()) { + return curr->first; + } else { + throw IteratorException("key() called on begin()"); + } + } + + mapped_type const& value() const { + if (curr != cont.end()) { + return curr->second; + } else { + throw IteratorException("value() called on begin()"); + } + } + + value_ref const& next() { + if (hasNext()) { + if (curr == cont.end()) + curr = cont.begin(); + else + ++curr; + return *curr; + } + throw IteratorException("next() called on end"); + } + + value_ref peekNext() const { + SMapIterator t = *this; + return t.next(); + } + + size_t distFront() const { + return std::distance(cont.begin(), curr); + } + + size_t distBack() const { + return std::distance(curr, cont.end()) - 1; + } + +protected: + SMapIterator& operator=(iterator const& i) { + return iterator::operator=(i); + } + Container const& cont; + iterator curr; +}; + +template <typename Container> +SMapIterator<Container> makeSMapIterator(Container const& c) { + return SMapIterator<Container>(c); +} + +template <typename Container> +class SMutableMapIterator { +public: + typedef typename Container::key_type key_type; + typedef typename Container::mapped_type mapped_type; + + typedef typename Container::iterator iterator; + typedef decltype(*iterator()) value_ref; + + SMutableMapIterator(Container& c) : cont(c) { + toFront(); + } + + void toFront() { + curr = cont.end(); + remCalled = false; + } + + void toBack() { + curr = cont.end(); + if (curr != cont.begin()) + --curr; + } + + bool hasNext() const { + iterator end = cont.end(); + if (curr == end) + return cont.begin() != end && !remCalled; + else if (remCalled) + return curr != end; + else + return ++iterator(curr) != end; + } + + key_type const& key() const { + if (remCalled) + throw IteratorException("key() called after remove()"); + else if (curr != cont.end()) + return curr->first; + else + throw IteratorException("key() called on begin()"); + } + + mapped_type& value() const { + if (remCalled) + throw IteratorException("value() called after remove()"); + else if (curr != cont.end()) + return curr->second; + else + throw IteratorException("value() called on begin()"); + } + + value_ref next() { + if (hasNext()) { + if (curr == cont.end()) + curr = cont.begin(); + else if (remCalled) + remCalled = false; + else + ++curr; + + return *curr; + } else { + throw IteratorException("next() called on end"); + } + } + + value_ref peekNext() const { + SMutableMapIterator t = *this; + return t.next(); + } + + void remove() { + if (remCalled) { + throw IteratorException("remove() called twice"); + } else if (curr == cont.end()) { + throw IteratorException("remove() called at front"); + } else { + if (curr == cont.begin()) { + cont.erase(curr); + curr = cont.end(); + } else { + curr = cont.erase(curr); + remCalled = true; + } + } + } + + size_t distFront() const { + if (curr == cont.end()) + return 0; + else + return std::distance(cont.begin(), curr) - (remCalled ? 1 : 0); + } + + size_t distBack() const { + if (curr == cont.end()) + return cont.size(); + else + return std::distance(curr, cont.end()) - 1 + (remCalled ? 1 : 0); + } + +private: + SMutableMapIterator& operator=(iterator const& i) { + return iterator::operator=(i); + } + + Container& cont; + iterator curr; + bool remCalled; +}; + +template <typename Container> +SMutableMapIterator<Container> makeSMutableMapIterator(Container& c) { + return SMutableMapIterator<Container>(c); +} + +} + +#endif diff --git a/source/core/StarJson.cpp b/source/core/StarJson.cpp new file mode 100644 index 0000000..7ef45ff --- /dev/null +++ b/source/core/StarJson.cpp @@ -0,0 +1,1015 @@ +#include "StarJson.hpp" +#include "StarJsonBuilder.hpp" +#include "StarJsonPath.hpp" +#include "StarFormat.hpp" +#include "StarLexicalCast.hpp" +#include "StarIterator.hpp" +#include "StarFile.hpp" + +namespace Star { + +Json::Type Json::typeFromName(String const& t) { + if (t == "float") + return Type::Float; + else if (t == "bool") + return Type::Bool; + else if (t == "int") + return Type::Int; + else if (t == "string") + return Type::String; + else if (t == "array") + return Type::Array; + else if (t == "object") + return Type::Object; + else if (t == "null") + return Type::Null; + else + throw JsonException(strf("String '%s' is not a valid json type", t)); +} + +String Json::typeName(Type t) { + switch (t) { + case Type::Float: + return "float"; + case Type::Bool: + return "bool"; + case Type::Int: + return "int"; + case Type::String: + return "string"; + case Type::Array: + return "array"; + case Type::Object: + return "object"; + default: + return "null"; + } +} + +bool Json::operator==(const Json& v) const { + if (type() == Type::Null && v.type() == Type::Null) { + return true; + } else if (type() != v.type()) { + if ((type() == Type::Float || type() == Type::Int) && (v.type() == Type::Float || v.type() == Type::Int)) + return toDouble() == v.toDouble() && toInt() == v.toInt(); + return false; + } else { + if (type() == Type::Float) + return m_data.get<double>() == v.m_data.get<double>(); + else if (type() == Type::Bool) + return m_data.get<bool>() == v.m_data.get<bool>(); + else if (type() == Type::Int) + return m_data.get<int64_t>() == v.m_data.get<int64_t>(); + else if (type() == Type::String) + return *m_data.get<StringConstPtr>() == *v.m_data.get<StringConstPtr>(); + else if (type() == Type::Array) + return *m_data.get<JsonArrayConstPtr>() == *v.m_data.get<JsonArrayConstPtr>(); + else if (type() == Type::Object) + return *m_data.get<JsonObjectConstPtr>() == *v.m_data.get<JsonObjectConstPtr>(); + } + return false; +} + +bool Json::operator!=(const Json& v) const { + return !(*this == v); +} + +bool Json::unique() const { + if (m_data.is<StringConstPtr>()) + return m_data.get<StringConstPtr>().unique(); + else if (m_data.is<JsonArrayConstPtr>()) + return m_data.get<JsonArrayConstPtr>().unique(); + else if (m_data.is<JsonObjectConstPtr>()) + return m_data.get<JsonObjectConstPtr>().unique(); + else + return true; +} + +Json Json::ofType(Type t) { + switch (t) { + case Type::Float: + return Json(0.0); + case Type::Bool: + return Json(false); + case Type::Int: + return Json(0); + case Type::String: + return Json(""); + case Type::Array: + return Json(JsonArray()); + case Type::Object: + return Json(JsonObject()); + default: + return Json(); + } +} + +Json Json::parse(String const& string) { + return inputUtf32Json<String::const_iterator>(string.begin(), string.end(), true); +} + +Json Json::parseJson(String const& json) { + return inputUtf32Json<String::const_iterator>(json.begin(), json.end(), false); +} + +Json::Json() {} + +Json::Json(double d) { + m_data = d; +} + +Json::Json(bool b) { + m_data = b; +} + +Json::Json(int i) { + m_data = (int64_t)i; +} + +Json::Json(long i) { + m_data = (int64_t)i; +} + +Json::Json(long long i) { + m_data = (int64_t)i; +} + +Json::Json(unsigned int i) { + m_data = (int64_t)i; +} + +Json::Json(unsigned long i) { + m_data = (int64_t)i; +} + +Json::Json(unsigned long long i) { + m_data = (int64_t)i; +} + +Json::Json(char const* s) { + m_data = make_shared<String const>(s); +} + +Json::Json(String::Char const* s) { + m_data = make_shared<String const>(s); +} + +Json::Json(String::Char const* s, size_t len) { + m_data = make_shared<String const>(s, len); +} + +Json::Json(String s) { + m_data = make_shared<String const>(std::move(s)); +} + +Json::Json(std::string s) { + m_data = make_shared<String const>((std::move(s))); +} + +Json::Json(JsonArray l) { + m_data = make_shared<JsonArray const>(std::move(l)); +} + +Json::Json(JsonObject m) { + m_data = make_shared<JsonObject const>(std::move(m)); +} + +double Json::toDouble() const { + if (type() == Type::Float) + return m_data.get<double>(); + if (type() == Type::Int) + return (double)m_data.get<int64_t>(); + + throw JsonException::format("Improper conversion to double from %s", typeName()); +} + +float Json::toFloat() const { + return (float)toDouble(); +} + +bool Json::toBool() const { + if (type() != Type::Bool) + throw JsonException::format("Improper conversion to bool from %s", typeName()); + return m_data.get<bool>(); +} + +int64_t Json::toInt() const { + if (type() == Type::Float) { + return (int64_t)m_data.get<double>(); + } else if (type() == Type::Int) { + return m_data.get<int64_t>(); + } else { + throw JsonException::format("Improper conversion to int from %s", typeName()); + } +} + +uint64_t Json::toUInt() const { + if (type() == Type::Float) { + return (uint64_t)m_data.get<double>(); + } else if (type() == Type::Int) { + return (uint64_t)m_data.get<int64_t>(); + } else { + throw JsonException::format("Improper conversion to unsigned int from %s", typeName()); + } +} + +String Json::toString() const { + if (type() != Type::String) + throw JsonException(strf("Cannot convert from %s to string", typeName())); + return *m_data.get<StringConstPtr>(); +} + +JsonArray Json::toArray() const { + if (type() != Type::Array) + throw JsonException::format("Improper conversion to JsonArray from %s", typeName()); + return *m_data.get<JsonArrayConstPtr>(); +} + +JsonObject Json::toObject() const { + if (type() != Type::Object) + throw JsonException::format("Improper conversion to JsonObject from %s", typeName()); + return *m_data.get<JsonObjectConstPtr>(); +} + +StringConstPtr Json::stringPtr() const { + if (type() != Type::String) + throw JsonException(strf("Cannot convert from %s to string", typeName())); + return m_data.get<StringConstPtr>(); +} + +JsonArrayConstPtr Json::arrayPtr() const { + if (type() != Type::Array) + throw JsonException::format("Improper conversion to JsonArray from %s", typeName()); + return m_data.get<JsonArrayConstPtr>(); +} + +JsonObjectConstPtr Json::objectPtr() const { + if (type() != Type::Object) + throw JsonException::format("Improper conversion to JsonObject from %s", typeName()); + return m_data.get<JsonObjectConstPtr>(); +} + +Json::IteratorWrapper<JsonArray> Json::iterateArray() const { + return IteratorWrapper<JsonArray>{arrayPtr()}; +} + +Json::IteratorWrapper<JsonObject> Json::iterateObject() const { + return IteratorWrapper<JsonObject>{objectPtr()}; +} + +Maybe<Json> Json::opt() const { + if (isNull()) + return {}; + return *this; +} + +Maybe<double> Json::optDouble() const { + if (isNull()) + return {}; + return toDouble(); +} + +Maybe<float> Json::optFloat() const { + if (isNull()) + return {}; + return toFloat(); +} + +Maybe<bool> Json::optBool() const { + if (isNull()) + return {}; + return toBool(); +} + +Maybe<int64_t> Json::optInt() const { + if (isNull()) + return {}; + return toInt(); +} + +Maybe<uint64_t> Json::optUInt() const { + if (isNull()) + return {}; + return toUInt(); +} + +Maybe<String> Json::optString() const { + if (isNull()) + return {}; + return toString(); +} + +Maybe<JsonArray> Json::optArray() const { + if (isNull()) + return {}; + return toArray(); +} + +Maybe<JsonObject> Json::optObject() const { + if (isNull()) + return {}; + return toObject(); +} + +size_t Json::size() const { + if (type() == Type::Array) + return m_data.get<JsonArrayConstPtr>()->size(); + else if (type() == Type::Object) + return m_data.get<JsonObjectConstPtr>()->size(); + else + throw JsonException("size() called on improper json type"); +} + +bool Json::contains(String const& key) const { + if (type() == Type::Object) + return m_data.get<JsonObjectConstPtr>()->contains(key); + else + throw JsonException("contains() called on improper json type"); +} + +Json Json::get(size_t index) const { + if (auto p = ptr(index)) + return *p; + throw JsonException(strf("Json::get(%s) out of range", index)); +} + +double Json::getDouble(size_t index) const { + return get(index).toDouble(); +} + +float Json::getFloat(size_t index) const { + return get(index).toFloat(); +} + +bool Json::getBool(size_t index) const { + return get(index).toBool(); +} + +int64_t Json::getInt(size_t index) const { + return get(index).toInt(); +} + +uint64_t Json::getUInt(size_t index) const { + return get(index).toUInt(); +} + +String Json::getString(size_t index) const { + return get(index).toString(); +} + +JsonArray Json::getArray(size_t index) const { + return get(index).toArray(); +} + +JsonObject Json::getObject(size_t index) const { + return get(index).toObject(); +} + +Json Json::get(size_t index, Json def) const { + if (auto p = ptr(index)) + return *p; + return def; +} + +double Json::getDouble(size_t index, double def) const { + if (auto p = ptr(index)) + return p->toDouble(); + return def; +} + +float Json::getFloat(size_t index, float def) const { + if (auto p = ptr(index)) + return p->toFloat(); + return def; +} + +bool Json::getBool(size_t index, bool def) const { + if (auto p = ptr(index)) + return p->toBool(); + return def; +} + +int64_t Json::getInt(size_t index, int64_t def) const { + if (auto p = ptr(index)) + return p->toInt(); + return def; +} + +uint64_t Json::getUInt(size_t index, int64_t def) const { + if (auto p = ptr(index)) + return p->toUInt(); + return def; +} + +String Json::getString(size_t index, String def) const { + if (auto p = ptr(index)) + return p->toString(); + return def; +} + +JsonArray Json::getArray(size_t index, JsonArray def) const { + if (auto p = ptr(index)) + return p->toArray(); + return def; +} + +JsonObject Json::getObject(size_t index, JsonObject def) const { + if (auto p = ptr(index)) + return p->toObject(); + return def; +} + +Json Json::get(String const& key) const { + if (auto p = ptr(key)) + return *p; + throw JsonException(strf("No such key in Json::get(\"%s\")", key)); +} + +double Json::getDouble(String const& key) const { + return get(key).toDouble(); +} + +float Json::getFloat(String const& key) const { + return get(key).toFloat(); +} + +bool Json::getBool(String const& key) const { + return get(key).toBool(); +} + +int64_t Json::getInt(String const& key) const { + return get(key).toInt(); +} + +uint64_t Json::getUInt(String const& key) const { + return get(key).toUInt(); +} + +String Json::getString(String const& key) const { + return get(key).toString(); +} + +JsonArray Json::getArray(String const& key) const { + return get(key).toArray(); +} + +JsonObject Json::getObject(String const& key) const { + return get(key).toObject(); +} + +Json Json::get(String const& key, Json def) const { + if (auto p = ptr(key)) + return *p; + return def; +} + +double Json::getDouble(String const& key, double def) const { + auto p = ptr(key); + if (p && *p) + return p->toDouble(); + return def; +} + +float Json::getFloat(String const& key, float def) const { + auto p = ptr(key); + if (p && *p) + return p->toFloat(); + return def; +} + +bool Json::getBool(String const& key, bool def) const { + auto p = ptr(key); + if (p && *p) + return p->toBool(); + return def; +} + +int64_t Json::getInt(String const& key, int64_t def) const { + auto p = ptr(key); + if (p && *p) + return p->toInt(); + return def; +} + +uint64_t Json::getUInt(String const& key, int64_t def) const { + auto p = ptr(key); + if (p && *p) + return p->toUInt(); + return def; +} + +String Json::getString(String const& key, String def) const { + auto p = ptr(key); + if (p && *p) + return p->toString(); + return def; +} + +JsonArray Json::getArray(String const& key, JsonArray def) const { + auto p = ptr(key); + if (p && *p) + return p->toArray(); + return def; +} + +JsonObject Json::getObject(String const& key, JsonObject def) const { + auto p = ptr(key); + if (p && *p) + return p->toObject(); + return def; +} + +Maybe<Json> Json::opt(String const& key) const { + auto p = ptr(key); + if (p && *p) + return *p; + return {}; +} + +Maybe<double> Json::optDouble(String const& key) const { + auto p = ptr(key); + if (p && *p) + return p->toDouble(); + return {}; +} + +Maybe<float> Json::optFloat(String const& key) const { + auto p = ptr(key); + if (p && *p) + return p->toFloat(); + return {}; +} + +Maybe<bool> Json::optBool(String const& key) const { + auto p = ptr(key); + if (p && *p) + return p->toBool(); + return {}; +} + +Maybe<int64_t> Json::optInt(String const& key) const { + auto p = ptr(key); + if (p && *p) + return p->toInt(); + return {}; +} + +Maybe<uint64_t> Json::optUInt(String const& key) const { + auto p = ptr(key); + if (p && *p) + return p->toUInt(); + return {}; +} + +Maybe<String> Json::optString(String const& key) const { + auto p = ptr(key); + if (p && *p) + return p->toString(); + return {}; +} + +Maybe<JsonArray> Json::optArray(String const& key) const { + auto p = ptr(key); + if (p && *p) + return p->toArray(); + return {}; +} + +Maybe<JsonObject> Json::optObject(String const& key) const { + auto p = ptr(key); + if (p && *p) + return p->toObject(); + return {}; +} + +Json Json::query(String const& q) const { + return JsonPath::pathGet(*this, JsonPath::parseQueryPath, q); +} + +double Json::queryDouble(String const& q) const { + return JsonPath::pathGet(*this, JsonPath::parseQueryPath, q).toDouble(); +} + +float Json::queryFloat(String const& q) const { + return JsonPath::pathGet(*this, JsonPath::parseQueryPath, q).toFloat(); +} + +bool Json::queryBool(String const& q) const { + return JsonPath::pathGet(*this, JsonPath::parseQueryPath, q).toBool(); +} + +int64_t Json::queryInt(String const& q) const { + return JsonPath::pathGet(*this, JsonPath::parseQueryPath, q).toInt(); +} + +uint64_t Json::queryUInt(String const& q) const { + return JsonPath::pathGet(*this, JsonPath::parseQueryPath, q).toUInt(); +} + +String Json::queryString(String const& q) const { + return JsonPath::pathGet(*this, JsonPath::parseQueryPath, q).toString(); +} + +JsonArray Json::queryArray(String const& q) const { + return JsonPath::pathGet(*this, JsonPath::parseQueryPath, q).toArray(); +} + +JsonObject Json::queryObject(String const& q) const { + return JsonPath::pathGet(*this, JsonPath::parseQueryPath, q).toObject(); +} + +Json Json::query(String const& query, Json def) const { + if (auto json = JsonPath::pathFind(*this, JsonPath::parseQueryPath, query)) + return *json; + return def; +} + +double Json::queryDouble(String const& query, double def) const { + auto json = JsonPath::pathFind(*this, JsonPath::parseQueryPath, query); + if (json && *json) + return json->toDouble(); + return def; +} + +float Json::queryFloat(String const& query, float def) const { + auto json = JsonPath::pathFind(*this, JsonPath::parseQueryPath, query); + if (json && *json) + return json->toFloat(); + return def; +} + +bool Json::queryBool(String const& query, bool def) const { + auto json = JsonPath::pathFind(*this, JsonPath::parseQueryPath, query); + if (json && *json) + return json->toBool(); + return def; +} + +int64_t Json::queryInt(String const& query, int64_t def) const { + auto json = JsonPath::pathFind(*this, JsonPath::parseQueryPath, query); + if (json && *json) + return json->toInt(); + return def; +} + +uint64_t Json::queryUInt(String const& query, uint64_t def) const { + auto json = JsonPath::pathFind(*this, JsonPath::parseQueryPath, query); + if (json && *json) + return json->toUInt(); + return def; +} + +String Json::queryString(String const& query, String def) const { + auto json = JsonPath::pathFind(*this, JsonPath::parseQueryPath, query); + if (json && *json) + return json->toString(); + return def; +} + +JsonArray Json::queryArray(String const& query, JsonArray def) const { + auto json = JsonPath::pathFind(*this, JsonPath::parseQueryPath, query); + if (json && *json) + return json->toArray(); + return def; +} + +JsonObject Json::queryObject(String const& query, JsonObject def) const { + auto json = JsonPath::pathFind(*this, JsonPath::parseQueryPath, query); + if (json && *json) + return json->toObject(); + return def; +} + +Maybe<Json> Json::optQuery(String const& path) const { + auto json = JsonPath::pathFind(*this, JsonPath::parseQueryPath, path); + if (json && *json) + return *json; + return {}; +} + +Maybe<double> Json::optQueryDouble(String const& path) const { + auto json = JsonPath::pathFind(*this, JsonPath::parseQueryPath, path); + if (json && *json) + return json->toDouble(); + return {}; +} + +Maybe<float> Json::optQueryFloat(String const& path) const { + auto json = JsonPath::pathFind(*this, JsonPath::parseQueryPath, path); + if (json && *json) + return json->toFloat(); + return {}; +} + +Maybe<bool> Json::optQueryBool(String const& path) const { + auto json = JsonPath::pathFind(*this, JsonPath::parseQueryPath, path); + if (json && *json) + return json->toBool(); + return {}; +} + +Maybe<int64_t> Json::optQueryInt(String const& path) const { + auto json = JsonPath::pathFind(*this, JsonPath::parseQueryPath, path); + if (json && *json) + return json->toInt(); + return {}; +} + +Maybe<uint64_t> Json::optQueryUInt(String const& path) const { + auto json = JsonPath::pathFind(*this, JsonPath::parseQueryPath, path); + if (json && *json) + return json->toUInt(); + return {}; +} + +Maybe<String> Json::optQueryString(String const& path) const { + auto json = JsonPath::pathFind(*this, JsonPath::parseQueryPath, path); + if (json && *json) + return json->toString(); + return {}; +} + +Maybe<JsonArray> Json::optQueryArray(String const& path) const { + auto json = JsonPath::pathFind(*this, JsonPath::parseQueryPath, path); + if (json && *json) + return json->toArray(); + return {}; +} + +Maybe<JsonObject> Json::optQueryObject(String const& path) const { + auto json = JsonPath::pathFind(*this, JsonPath::parseQueryPath, path); + if (json && *json) + return json->toObject(); + return {}; +} + +Json Json::set(String key, Json value) const { + auto map = toObject(); + map[move(key)] = move(value); + return map; +} + +Json Json::setPath(String path, Json value) const { + return JsonPath::pathSet(*this, JsonPath::parseQueryPath, path, value); +} + +Json Json::erasePath(String path) const { + return JsonPath::pathRemove(*this, JsonPath::parseQueryPath, path); +} + +Json Json::setAll(JsonObject values) const { + auto map = toObject(); + for (auto& p : values) + map[move(p.first)] = move(p.second); + return map; +} + +Json Json::eraseKey(String key) const { + auto map = toObject(); + map.erase(move(key)); + return map; +} + +Json Json::set(size_t index, Json value) const { + auto array = toArray(); + array[index] = move(value); + return array; +} + +Json Json::insert(size_t index, Json value) const { + auto array = toArray(); + array.insertAt(index, move(value)); + return array; +} + +Json Json::append(Json value) const { + auto array = toArray(); + array.append(move(value)); + return array; +} + +Json Json::eraseIndex(size_t index) const { + auto array = toArray(); + array.eraseAt(index); + return array; +} + +Json::Type Json::type() const { + return (Type)m_data.typeIndex(); +} + +String Json::typeName() const { + return typeName(type()); +} + +Json Json::convert(Type u) const { + if (type() == u) + return *this; + + switch (u) { + case Type::Null: + return Json(); + case Type::Float: + return toDouble(); + case Type::Bool: + return toBool(); + case Type::Int: + return toInt(); + case Type::String: + return toString(); + case Type::Array: + return toArray(); + case Type::Object: + return toObject(); + default: + throw JsonException::format("Improper conversion to type %s", typeName(u)); + } +} + +bool Json::isType(Type t) const { + return type() == t; +} + +bool Json::canConvert(Type t) const { + if (type() == t) + return true; + + if (t == Type::Null) + return true; + + if ((type() == Type::Float || type() == Type::Int) && (t == Type::Float || t == Type::Int)) + return true; + + return false; +} + +bool Json::isNull() const { + return type() == Type::Null; +} + +Json::operator bool() const { + return !isNull(); +} + +String Json::repr(int pretty, bool sort) const { + String result; + outputUtf32Json(*this, std::back_inserter(result), pretty, sort); + return result; +} + +String Json::printJson(int pretty, bool sort) const { + if (type() != Type::Object && type() != Type::Array) + throw JsonException("printJson called on non-top-level JSON type"); + + return repr(pretty, sort); +} + +std::ostream& operator<<(std::ostream& os, Json const& v) { + outputUtf8Json(v, std::ostream_iterator<char>(os), 0, false); + return os; +} + +std::ostream& operator<<(std::ostream& os, JsonObject const& v) { + // Blargh copy + os << Json(v); + return os; +} + +DataStream& operator<<(DataStream& os, const Json& v) { + // Compatibility with old serialization, 0 was INVALID but INVALID is no + // longer used. + os.write<uint8_t>((uint8_t)v.type() + 1); + + if (v.type() == Json::Type::Float) { + os.write<double>(v.toDouble()); + } else if (v.type() == Json::Type::Bool) { + os.write<bool>(v.toBool()); + } else if (v.type() == Json::Type::Int) { + os.writeVlqI(v.toInt()); + } else if (v.type() == Json::Type::String) { + os.write<String>(v.toString()); + } else if (v.type() == Json::Type::Array) { + auto const& l = v.toArray(); + os.writeVlqU(l.size()); + for (auto const& v : l) + os.write<Json>(v); + } else if (v.type() == Json::Type::Object) { + auto const& m = v.toObject(); + os.writeVlqU(m.size()); + for (auto const& v : m) { + os.write<String>(v.first); + os.write<Json>(v.second); + } + } + return os; +} + +DataStream& operator>>(DataStream& os, Json& v) { + // Compatibility with old serialization, 0 was INVALID but INVALID is no + // longer used. + uint8_t typeIndex = os.read<uint8_t>(); + if (typeIndex > 0) + typeIndex -= 1; + + Json::Type type = (Json::Type)typeIndex; + + if (type == Json::Type::Float) { + v = Json(os.read<double>()); + } else if (type == Json::Type::Bool) { + v = Json(os.read<bool>()); + } else if (type == Json::Type::Int) { + v = Json(os.readVlqI()); + } else if (type == Json::Type::String) { + v = Json(os.read<String>()); + } else if (type == Json::Type::Array) { + JsonArray l; + size_t s = os.readVlqU(); + for (size_t i = 0; i < s; ++i) + l.append(os.read<Json>()); + + v = move(l); + } else if (type == Json::Type::Object) { + JsonObject m; + size_t s = os.readVlqU(); + for (size_t i = 0; i < s; ++i) { + String k = os.read<String>(); + m[k] = os.read<Json>(); + } + + v = move(m); + } + + return os; +} + +DataStream& operator<<(DataStream& ds, JsonArray const& l) { + ds.writeContainer(l); + return ds; +} + +DataStream& operator>>(DataStream& ds, JsonArray& l) { + ds.readContainer(l); + return ds; +} + +DataStream& operator<<(DataStream& ds, JsonObject const& m) { + ds.writeMapContainer(m); + return ds; +} + +DataStream& operator>>(DataStream& ds, JsonObject& m) { + ds.readMapContainer(m); + return ds; +} + +size_t hash<Json>::operator()(Json const& v) const { + // This is probably a bit slow and weird, using the utf-8 output printer to + // produce a Json hash. + + size_t h = 0; + auto collector = [&h](char c) { h = h * 101 + c; }; + outputUtf8Json(v, makeFunctionOutputIterator(collector), 0, true); + return h; +} + +Json const* Json::ptr(size_t index) const { + if (type() != Type::Array) + throw JsonException::format("Cannot call get with index on Json type %s, must be Array type", typeName()); + + auto const& list = *m_data.get<JsonArrayConstPtr>(); + if (index >= list.size()) + return nullptr; + return &list[index]; +} + +Json const* Json::ptr(String const& key) const { + if (type() != Type::Object) + throw JsonException::format("Cannot call get with key on Json type %s, must be Object type", typeName()); + auto const& map = m_data.get<JsonObjectConstPtr>(); + + auto i = map->find(key); + if (i == map->end()) + return nullptr; + + return &i->second; +} + +Json jsonMerge(Json const& base, Json const& merger) { + if (base.type() == Json::Type::Object && merger.type() == Json::Type::Object) { + JsonObject merged = base.toObject(); + for (auto const& p : merger.toObject()) { + auto res = merged.insert(p); + if (!res.second) + res.first->second = jsonMerge(res.first->second, p.second); + } + return move(merged); + + } else if (merger.type() == Json::Type::Null) { + return base; + + } else { + return merger; + } +} + +} diff --git a/source/core/StarJson.hpp b/source/core/StarJson.hpp new file mode 100644 index 0000000..f14fbf8 --- /dev/null +++ b/source/core/StarJson.hpp @@ -0,0 +1,358 @@ +#ifndef STAR_JSON_HPP +#define STAR_JSON_HPP + +#include "StarDataStream.hpp" +#include "StarVariant.hpp" +#include "StarString.hpp" + +namespace Star { + +STAR_EXCEPTION(JsonException, StarException); +STAR_EXCEPTION(JsonParsingException, StarException); + +STAR_CLASS(Json); + +typedef List<Json> JsonArray; +typedef shared_ptr<JsonArray const> JsonArrayConstPtr; + +typedef StringMap<Json> JsonObject; +typedef shared_ptr<JsonObject const> JsonObjectConstPtr; + +// Class for holding representation of JSON data. Immutable and implicitly +// shared. +class Json { +public: + template <typename Container> + struct IteratorWrapper { + typedef typename Container::const_iterator const_iterator; + typedef const_iterator iterator; + + const_iterator begin() const; + const_iterator end() const; + + shared_ptr<Container const> ptr; + }; + + enum class Type : uint8_t { + Null = 0, + Float = 1, + Bool = 2, + Int = 3, + String = 4, + Array = 5, + Object = 6 + }; + + static String typeName(Type t); + static Type typeFromName(String const& t); + + static Json ofType(Type t); + + // Parses JSON or JSON sub-type + static Json parse(String const& string); + + // Parses JSON object or array only (the only top level types allowed by + // JSON) + static Json parseJson(String const& json); + + // Constructs type Null + Json(); + + Json(double); + Json(bool); + Json(int); + Json(long); + Json(long long); + Json(unsigned int); + Json(unsigned long); + Json(unsigned long long); + Json(char const*); + Json(String::Char const*); + Json(String::Char const*, size_t); + Json(String); + Json(std::string); + Json(JsonArray); + Json(JsonObject); + + // Float and Int types are convertible between each other. toDouble, + // toFloat, toInt, toUInt may be called on either an Int or a Float. For a + // Float this is simply a C style cast from double, and for an Int it is + // simply a C style cast from int64_t. + // + // Bools, Strings, Arrays, Objects, and Null are not automatically + // convertible to any other type. + + double toDouble() const; + float toFloat() const; + bool toBool() const; + int64_t toInt() const; + uint64_t toUInt() const; + String toString() const; + JsonArray toArray() const; + JsonObject toObject() const; + + // Internally, String, JsonArray, and JsonObject are shared via shared_ptr + // since this class is immutable. Use these methods to get at this pointer + // without causing a copy. + StringConstPtr stringPtr() const; + JsonArrayConstPtr arrayPtr() const; + JsonObjectConstPtr objectPtr() const; + + // As a convenience, make it easy to safely and quickly iterate over a + // JsonArray or JsonObject contents by holding the container pointer. + IteratorWrapper<JsonArray> iterateArray() const; + IteratorWrapper<JsonObject> iterateObject() const; + + // opt* methods work like this, if the json is null, it returns none. If the + // json is convertible, it returns the converted type, otherwise an exception + // occurrs. + Maybe<Json> opt() const; + Maybe<double> optDouble() const; + Maybe<float> optFloat() const; + Maybe<bool> optBool() const; + Maybe<int64_t> optInt() const; + Maybe<uint64_t> optUInt() const; + Maybe<String> optString() const; + Maybe<JsonArray> optArray() const; + Maybe<JsonObject> optObject() const; + + // Size of array / object type json + size_t size() const; + + // If this json is array type, get the value at the given index + Json get(size_t index) const; + double getDouble(size_t index) const; + float getFloat(size_t index) const; + bool getBool(size_t index) const; + int64_t getInt(size_t index) const; + uint64_t getUInt(size_t index) const; + String getString(size_t index) const; + JsonArray getArray(size_t index) const; + JsonObject getObject(size_t index) const; + + // These versions of get* return default value if the index is out of range, + // or if the value pointed to is null. + Json get(size_t index, Json def) const; + double getDouble(size_t index, double def) const; + float getFloat(size_t index, float def) const; + bool getBool(size_t index, bool def) const; + int64_t getInt(size_t index, int64_t def) const; + uint64_t getUInt(size_t index, int64_t def) const; + String getString(size_t index, String def) const; + JsonArray getArray(size_t index, JsonArray def) const; + JsonObject getObject(size_t index, JsonObject def) const; + + // If object type, whether object contains key + bool contains(String const& key) const; + + // If this json is object type, get the value for the given key + Json get(String const& key) const; + double getDouble(String const& key) const; + float getFloat(String const& key) const; + bool getBool(String const& key) const; + int64_t getInt(String const& key) const; + uint64_t getUInt(String const& key) const; + String getString(String const& key) const; + JsonArray getArray(String const& key) const; + JsonObject getObject(String const& key) const; + + // These versions of get* return the default if the key is missing or the + // value is null. + Json get(String const& key, Json def) const; + double getDouble(String const& key, double def) const; + float getFloat(String const& key, float def) const; + bool getBool(String const& key, bool def) const; + int64_t getInt(String const& key, int64_t def) const; + uint64_t getUInt(String const& key, int64_t def) const; + String getString(String const& key, String def) const; + JsonArray getArray(String const& key, JsonArray def) const; + JsonObject getObject(String const& key, JsonObject def) const; + + // Works the same way as opt methods above. Will never return a null value, + // if there is a null entry it will just return an empty Maybe. + Maybe<Json> opt(String const& key) const; + Maybe<double> optDouble(String const& key) const; + Maybe<float> optFloat(String const& key) const; + Maybe<bool> optBool(String const& key) const; + Maybe<int64_t> optInt(String const& key) const; + Maybe<uint64_t> optUInt(String const& key) const; + Maybe<String> optString(String const& key) const; + Maybe<JsonArray> optArray(String const& key) const; + Maybe<JsonObject> optObject(String const& key) const; + + // Combines gets recursively in friendly expressions. For + // example, call like this: json.query("path.to.array[3][4]") + Json query(String const& path) const; + double queryDouble(String const& path) const; + float queryFloat(String const& path) const; + bool queryBool(String const& path) const; + int64_t queryInt(String const& path) const; + uint64_t queryUInt(String const& path) const; + String queryString(String const& path) const; + JsonArray queryArray(String const& path) const; + JsonObject queryObject(String const& path) const; + + // These versions of get* do not throw on missing / null keys anywhere in the + // query path. + Json query(String const& path, Json def) const; + double queryDouble(String const& path, double def) const; + float queryFloat(String const& path, float def) const; + bool queryBool(String const& path, bool def) const; + int64_t queryInt(String const& path, int64_t def) const; + uint64_t queryUInt(String const& path, uint64_t def) const; + String queryString(String const& path, String def) const; + JsonArray queryArray(String const& path, JsonArray def) const; + JsonObject queryObject(String const& path, JsonObject def) const; + + // Returns none on on missing / null keys anywhere in the query path. Will + // never return a null value, just an empty Maybe. + Maybe<Json> optQuery(String const& path) const; + Maybe<double> optQueryDouble(String const& path) const; + Maybe<float> optQueryFloat(String const& path) const; + Maybe<bool> optQueryBool(String const& path) const; + Maybe<int64_t> optQueryInt(String const& path) const; + Maybe<uint64_t> optQueryUInt(String const& path) const; + Maybe<String> optQueryString(String const& path) const; + Maybe<JsonArray> optQueryArray(String const& path) const; + Maybe<JsonObject> optQueryObject(String const& path) const; + + // Returns a *new* object with the given values set/erased. Throws if not an + // object. + Json set(String key, Json value) const; + Json setPath(String path, Json value) const; + Json setAll(JsonObject values) const; + Json eraseKey(String key) const; + Json erasePath(String path) const; + + // Returns a *new* array with the given values set/inserted/appended/erased. + // Throws if not an array. + Json set(size_t index, Json value) const; + Json insert(size_t index, Json value) const; + Json append(Json value) const; + Json eraseIndex(size_t index) const; + + Type type() const; + String typeName() const; + Json convert(Type u) const; + + bool isType(Type type) const; + bool canConvert(Type type) const; + + // isNull returns true when the type of the Json is null. operator bool() is + // the opposite of isNull(). + bool isNull() const; + explicit operator bool() const; + + // Prints JSON or JSON sub-type. If sort is true, then any object anywhere + // inside this value will be sorted alphanumerically before being written, + // resulting in a known *unique* textual representation of the Json that is + // cross-platform. + String repr(int pretty = 0, bool sort = false) const; + // Prints JSON object or array only (only top level types allowed by JSON) + String printJson(int pretty = 0, bool sort = false) const; + + // operator== and operator!= compare for exact equality with all types, and + // additionally equality with numeric conversion with Int <-> Float + bool operator==(Json const& v) const; + bool operator!=(Json const& v) const; + + // Does this Json not share its storage with any other Json? + bool unique() const; + +private: + Json const* ptr(size_t index) const; + Json const* ptr(String const& key) const; + + Variant<Empty, double, bool, int64_t, StringConstPtr, JsonArrayConstPtr, JsonObjectConstPtr> m_data; +}; + +std::ostream& operator<<(std::ostream& os, Json const& v); + +// Fixes ambiguity with OrderedHashMap operator<< +std::ostream& operator<<(std::ostream& os, JsonObject const& v); + +// Serialize json to DataStream. Strings are stored as UTF-8, ints are stored +// as VLQ, doubles as 64 bit. +DataStream& operator<<(DataStream& ds, Json const& v); +DataStream& operator>>(DataStream& ds, Json& v); + +// Convenience methods for Json containers +DataStream& operator<<(DataStream& ds, JsonArray const& l); +DataStream& operator>>(DataStream& ds, JsonArray& l); +DataStream& operator<<(DataStream& ds, JsonObject const& m); +DataStream& operator>>(DataStream& ds, JsonObject& m); + +// Merges the two given json values and returns the result, by the following +// rules (applied in order): If the base value is null, returns the merger. +// If the merger value is null, returns base. For any two non-objects types, +// returns the merger. If both values are objects, then the resulting object +// is the combination of both objects, but for each repeated key jsonMerge is +// called recursively on both values to determine the result. +Json jsonMerge(Json const& base, Json const& merger); + +template <typename... T> +Json jsonMerge(Json const& base, Json const& merger, T const&... rest); + +// Similar to jsonMerge, but query only for a single key. Gets a value equal +// to jsonMerge(jsons...).query(key, Json()), but much faster than doing an +// entire merge operation. +template <typename... T> +Json jsonMergeQuery(String const& key, Json const& first, T const&... rest); + +// jsonMergeQuery with a default. +template <typename... T> +Json jsonMergeQueryDef(String const& key, Json def, Json const& first, T const&... rest); + +template <> +struct hash<Json> { + size_t operator()(Json const& v) const; +}; + +template <typename Container> +auto Json::IteratorWrapper<Container>::begin() const -> const_iterator { + return ptr->begin(); +} + +template <typename Container> +auto Json::IteratorWrapper<Container>::end() const -> const_iterator { + return ptr->end(); +} + +template <typename... T> +Json jsonMerge(Json const& base, Json const& merger, T const&... rest) { + return jsonMerge(jsonMerge(base, merger), rest...); +} + +template <typename... T> +Json jsonMergeQuery(String const&, Json def) { + return def; +} + +template <typename... T> +Json jsonMergeQueryImpl(String const& key, Json const& json) { + return json.query(key, {}); +} + +template <typename... T> +Json jsonMergeQueryImpl(String const& key, Json const& base, Json const& first, T const&... rest) { + Json value = jsonMergeQueryImpl(key, first, rest...); + if (value && !value.isType(Json::Type::Object)) + return value; + return jsonMerge(base.query(key, {}), value); +} + +template <typename... T> +Json jsonMergeQuery(String const& key, Json const& first, T const&... rest) { + return jsonMergeQueryImpl(key, first, rest...); +} + +template <typename... T> +Json jsonMergeQueryDef(String const& key, Json def, Json const& first, T const&... rest) { + if (auto v = jsonMergeQueryImpl(key, first, rest...)) + return v; + return def; +} + +} + +#endif diff --git a/source/core/StarJsonBuilder.cpp b/source/core/StarJsonBuilder.cpp new file mode 100644 index 0000000..432cef9 --- /dev/null +++ b/source/core/StarJsonBuilder.cpp @@ -0,0 +1,168 @@ +#include "StarJsonBuilder.hpp" +#include "StarLexicalCast.hpp" + +namespace Star { + +void JsonBuilderStream::beginObject() { + pushSentry(); +} + +void JsonBuilderStream::objectKey(char32_t const* s, size_t len) { + push(Json(s, len)); +} + +void JsonBuilderStream::endObject() { + JsonObject object; + while (true) { + if (isSentry()) { + set(Json(move(object))); + return; + } else { + Json v = pop(); + String k = pop().toString(); + if (!object.insert(k, move(v)).second) + throw JsonParsingException(strf("Json object contains a duplicate entry for key '%s'", k)); + } + } +} + +void JsonBuilderStream::beginArray() { + pushSentry(); +} + +void JsonBuilderStream::endArray() { + JsonArray array; + while (true) { + if (isSentry()) { + array.reverse(); + set(Json(move(array))); + return; + } else { + array.append(pop()); + } + } +} + +void JsonBuilderStream::putString(char32_t const* s, size_t len) { + push(Json(s, len)); +} + +void JsonBuilderStream::putDouble(char32_t const* s, size_t len) { + push(Json(lexicalCast<double>(String(s, len)))); +} + +void JsonBuilderStream::putInteger(char32_t const* s, size_t len) { + push(Json(lexicalCast<long long>(String(s, len)))); +} + +void JsonBuilderStream::putBoolean(bool b) { + push(Json(b)); +} + +void JsonBuilderStream::putNull() { + push(Json()); +} + +void JsonBuilderStream::putWhitespace(char32_t const*, size_t) {} + +void JsonBuilderStream::putColon() {} + +void JsonBuilderStream::putComma() {} + +size_t JsonBuilderStream::stackSize() { + return m_stack.size(); +} + +Json JsonBuilderStream::takeTop() { + if (m_stack.size()) + return m_stack.takeLast().take(); + else + return Json(); +} + +void JsonBuilderStream::push(Json v) { + m_stack.append(move(v)); +} + +Json JsonBuilderStream::pop() { + return m_stack.takeLast().take(); +} + +void JsonBuilderStream::set(Json v) { + m_stack.last() = move(v); +} + +void JsonBuilderStream::pushSentry() { + m_stack.append({}); +} + +bool JsonBuilderStream::isSentry() { + return !m_stack.empty() && !m_stack.last(); +} + +void JsonStreamer<Json>::toJsonStream(Json const& val, JsonStream& stream, bool sort) { + Json::Type type = val.type(); + if (type == Json::Type::Null) { + stream.putNull(); + } else if (type == Json::Type::Float) { + auto d = String(toString(val.toDouble())).wideString(); + stream.putDouble(d.c_str(), d.length()); + } else if (type == Json::Type::Bool) { + stream.putBoolean(val.toBool()); + } else if (type == Json::Type::Int) { + auto i = String(toString(val.toInt())).wideString(); + stream.putInteger(i.c_str(), i.length()); + } else if (type == Json::Type::String) { + auto ws = val.toString().wideString(); + stream.putString(ws.c_str(), ws.length()); + } else if (type == Json::Type::Array) { + stream.beginArray(); + bool first = true; + for (auto const& elem : val.iterateArray()) { + if (!first) + stream.putComma(); + first = false; + toJsonStream(elem, stream, sort); + } + stream.endArray(); + } else if (type == Json::Type::Object) { + stream.beginObject(); + List<String::Char> chars; + if (sort) { + auto objectPtr = val.objectPtr(); + List<JsonObject::const_iterator> iterators; + iterators.reserve(objectPtr->size()); + for (auto i = objectPtr->begin(); i != objectPtr->end(); ++i) + iterators.append(i); + iterators.sort([](JsonObject::const_iterator a, JsonObject::const_iterator b) { + return a->first < b->first; + }); + bool first = true; + for (auto const& i : iterators) { + if (!first) + stream.putComma(); + first = false; + chars.clear(); + for (auto const& c : i->first) + chars.push_back(c); + stream.objectKey(chars.ptr(), chars.size()); + stream.putColon(); + toJsonStream(i->second, stream, sort); + } + } else { + bool first = true; + for (auto const& pair : val.iterateObject()) { + if (!first) + stream.putComma(); + first = false; + auto ws = pair.first.wideString(); + stream.objectKey(ws.c_str(), ws.length()); + stream.putColon(); + toJsonStream(pair.second, stream, sort); + } + } + stream.endObject(); + } +} + +} diff --git a/source/core/StarJsonBuilder.hpp b/source/core/StarJsonBuilder.hpp new file mode 100644 index 0000000..cd2d42d --- /dev/null +++ b/source/core/StarJsonBuilder.hpp @@ -0,0 +1,104 @@ +#ifndef STAR_JSON_BUILDER_HPP +#define STAR_JSON_BUILDER_HPP + +#include "StarJsonParser.hpp" +#include "StarJson.hpp" + +namespace Star { + +class JsonBuilderStream : public JsonStream { +public: + virtual void beginObject(); + virtual void objectKey(char32_t const* s, size_t len); + virtual void endObject(); + + virtual void beginArray(); + virtual void endArray(); + + virtual void putString(char32_t const* s, size_t len); + virtual void putDouble(char32_t const* s, size_t len); + virtual void putInteger(char32_t const* s, size_t len); + virtual void putBoolean(bool b); + virtual void putNull(); + + virtual void putWhitespace(char32_t const* s, size_t len); + virtual void putColon(); + virtual void putComma(); + + size_t stackSize(); + Json takeTop(); + +private: + void push(Json v); + Json pop(); + void set(Json v); + void pushSentry(); + bool isSentry(); + + List<Maybe<Json>> m_stack; +}; + +template <typename Jsonlike> +class JsonStreamer { +public: + static void toJsonStream(Jsonlike const& val, JsonStream& stream, bool sort); +}; + +template <> +class JsonStreamer<Json> { +public: + static void toJsonStream(Json const& val, JsonStream& stream, bool sort); +}; + +template <typename InputIterator> +Json inputUtf8Json(InputIterator begin, InputIterator end, bool fragment) { + typedef U8ToU32Iterator<InputIterator> Utf32Input; + typedef JsonParser<Utf32Input> Parser; + + JsonBuilderStream stream; + Parser parser(stream); + Utf32Input wbegin(begin); + Utf32Input wend(end); + Utf32Input pend = parser.parse(wbegin, wend, fragment); + + if (parser.error()) + throw JsonParsingException(strf("Error parsing json: %s at %s:%s", parser.error(), parser.line(), parser.column())); + else if (pend != wend) + throw JsonParsingException(strf("Error extra data at end of input at %s:%s", parser.line(), parser.column())); + + return stream.takeTop(); +} + +template <typename OutputIterator> +void outputUtf8Json(Json const& val, OutputIterator out, int pretty, bool sort) { + typedef Utf8OutputIterator<OutputIterator> Utf8Output; + typedef JsonWriter<Utf8Output> Writer; + Writer writer(Utf8Output(out), pretty); + JsonStreamer<Json>::toJsonStream(val, writer, sort); +} + +template <typename InputIterator, typename Stream = JsonBuilderStream, typename Jsonlike = Json> +Jsonlike inputUtf32Json(InputIterator begin, InputIterator end, bool fragment) { + Stream stream; + JsonParser<InputIterator> parser(stream); + + InputIterator pend = parser.parse(begin, end, fragment); + + if (parser.error()) { + throw JsonParsingException(strf("Error parsing json: %s at %s:%s", parser.error(), parser.line(), parser.column())); + } else if (pend != end) { + throw JsonParsingException(strf("Error extra data at end of input at %s:%s", parser.line(), parser.column())); + } + + return stream.takeTop(); +} + +template <typename OutputIterator, typename Jsonlike = Json> +void outputUtf32Json(Jsonlike const& val, OutputIterator out, int pretty, bool sort) { + JsonWriter<OutputIterator> writer(out, pretty); + JsonStreamer<Jsonlike>::toJsonStream(val, writer, sort); +} + +} + +#endif diff --git a/source/core/StarJsonExtra.cpp b/source/core/StarJsonExtra.cpp new file mode 100644 index 0000000..7a660d1 --- /dev/null +++ b/source/core/StarJsonExtra.cpp @@ -0,0 +1,458 @@ +#include "StarJsonExtra.hpp" +#include "StarLogging.hpp" +#include "StarRandom.hpp" + +namespace Star { + +size_t jsonToSize(Json const& v) { + if (v.isNull()) + return NPos; + + if (!v.canConvert(Json::Type::Int)) + throw JsonException("Json not an int in jsonToSize"); + + return v.toUInt(); +} + +Json jsonFromSize(size_t s) { + if (s == NPos) + return Json(); + return Json(s); +} + +Vec2D jsonToVec2D(Json const& v) { + if (v.type() != Json::Type::Array || v.size() != 2) + throw JsonException("Json not an array of size 2 in jsonToVec2D"); + + return Vec2D(v.getDouble(0), v.getDouble(1)); +} + +Vec2F jsonToVec2F(Json const& v) { + if (v.type() != Json::Type::Array || v.size() != 2) + throw JsonException("Json not an array of size 2 in jsonToVec2F"); + + return Vec2F(v.getFloat(0), v.getFloat(1)); +} + +Json jsonFromVec2F(Vec2F const& v) { + return JsonArray{v[0], v[1]}; +} + +Vec2I jsonToVec2I(Json const& v) { + if (v.type() != Json::Type::Array || v.size() != 2) + throw JsonException("Json not an array of size 2 in jsonToVec2I"); + + return Vec2I(v.getInt(0), v.getInt(1)); +} + +Json jsonFromVec2I(Vec2I const& v) { + return JsonArray{v[0], v[1]}; +} + +Vec2U jsonToVec2U(Json const& v) { + if (v.type() != Json::Type::Array || v.size() != 2) + throw JsonException("Json not an array of size 2 in jsonToVec2I"); + + return Vec2U(v.getInt(0), v.getInt(1)); +} + +Json jsonFromVec2U(Vec2U const& v) { + return JsonArray{v[0], v[1]}; +} + +Vec2B jsonToVec2B(Json const& v) { + if (v.type() != Json::Type::Array || v.size() != 2) + throw JsonException("Json not an array of size 2 in jsonToVec2B"); + + return Vec2B(v.getInt(0), v.getInt(1)); +} + +Json jsonFromVec2B(Vec2B const& v) { + return JsonArray{v[0], v[1]}; +} + +Vec3D jsonToVec3D(Json const& v) { + if (v.type() != Json::Type::Array || v.size() != 3) + throw JsonException("Json not an array of size size 3 in jsonToVec3D"); + + return Vec3D(v.getDouble(0), v.getDouble(1), v.getDouble(2)); +} + +Vec3F jsonToVec3F(Json const& v) { + if (v.type() != Json::Type::Array || v.size() != 3) + throw JsonException("Json not an array of size 3 in jsonToVec3D"); + + return Vec3F(v.getFloat(0), v.getFloat(1), v.getFloat(2)); +} + +Json jsonFromVec3F(Vec3F const& v) { + return JsonArray{v[0], v[1], v[2]}; +} + +Vec3I jsonToVec3I(Json const& v) { + if (v.type() != Json::Type::Array || v.size() != 3) + throw JsonException("Json not an array of size 3 in jsonToVec3I"); + + return Vec3I(v.getInt(0), v.getInt(1), v.getInt(2)); +} + +Json jsonFromVec3I(Vec3I const& v) { + JsonArray result; + result.append(v[0]); + result.append(v[1]); + result.append(v[2]); + return result; +} + +Vec3B jsonToVec3B(Json const& v) { + if (v.type() != Json::Type::Array || v.size() != 3) + throw JsonException("Json not an array of size 3 in jsonToVec3B"); + + return Vec3B(v.getInt(0), v.getInt(1), v.getInt(2)); +} + +Vec4B jsonToVec4B(Json const& v) { + if (v.type() != Json::Type::Array || v.size() != 4) + throw JsonException("Json not an array of size 4 in jsonToVec4B"); + + return Vec4B(v.getInt(0), v.getInt(1), v.getInt(2), v.getInt(3)); +} + +Vec4I jsonToVec4I(Json const& v) { + if (v.type() != Json::Type::Array || v.size() != 4) + throw JsonException("Json not an array of size 4 in jsonToVec4B"); + + return Vec4I(v.getInt(0), v.getInt(1), v.getInt(2), v.getInt(3)); +} + +Vec4F jsonToVec4F(Json const& v) { + if (v.type() != Json::Type::Array || v.size() != 4) + throw JsonException("Json not an array of size 4 in jsonToVec4B"); + + return Vec4F(v.getFloat(0), v.getFloat(1), v.getFloat(2), v.getFloat(3)); +} + +RectD jsonToRectD(Json const& v) { + if (v.type() != Json::Type::Array) + throw JsonException("Json not an array in jsonToRectD"); + + if (v.size() != 4 && v.size() != 2) + throw JsonException("Json not an array of proper size in jsonToRectD"); + + if (v.size() == 4) + return RectD(v.getDouble(0), v.getDouble(1), v.getDouble(2), v.getDouble(3)); + + try { + auto lowerLeft = jsonToVec2D(v.get(0)); + auto upperRight = jsonToVec2D(v.get(1)); + return RectD(lowerLeft, upperRight); + } catch (JsonException const& e) { + throw JsonException(strf("Inner position not well formed in jsonToRectD: %s", outputException(e, true))); + } +} + +Json jsonFromRectD(RectD const& rect) { + return JsonArray{rect.xMin(), rect.yMin(), rect.xMax(), rect.yMax()}; +} + +RectF jsonToRectF(Json const& v) { + return RectF(jsonToRectD(v)); +} + +Json jsonFromRectF(RectF const& rect) { + return JsonArray{rect.xMin(), rect.yMin(), rect.xMax(), rect.yMax()}; +} + +RectI jsonToRectI(Json const& v) { + if (v.type() != Json::Type::Array) + throw JsonException("Json not an array in jsonToRectI"); + + if (v.size() != 4 && v.size() != 2) + throw JsonException("Json not an array of proper size in jsonToRectI"); + + if (v.size() == 4) + return RectI(v.getInt(0), v.getInt(1), v.getInt(2), v.getInt(3)); + + try { + auto lowerLeft = jsonToVec2I(v.get(0)); + auto upperRight = jsonToVec2I(v.get(1)); + return RectI(lowerLeft, upperRight); + } catch (JsonException const& e) { + throw JsonException(strf("Inner position not well formed in jsonToRectI: %s", outputException(e, true))); + } +} + +Json jsonFromRectI(RectI const& rect) { + return JsonArray{rect.xMin(), rect.yMin(), rect.xMax(), rect.yMax()}; +} + +RectU jsonToRectU(Json const& v) { + if (v.type() != Json::Type::Array) + throw JsonException("Json not an array in jsonToRectU"); + + if (v.size() != 4 && v.size() != 2) + throw JsonException("Json not an array of proper size in jsonToRectU"); + + if (v.size() == 4) + return RectU(v.getInt(0), v.getUInt(1), v.getUInt(2), v.getUInt(3)); + + try { + auto lowerLeft = jsonToVec2U(v.get(0)); + auto upperRight = jsonToVec2U(v.get(1)); + return RectU(lowerLeft, upperRight); + } catch (JsonException const& e) { + throw JsonException(strf("Inner position not well formed in jsonToRectU: %s", outputException(e, true))); + } +} + +Json jsonFromRectU(RectU const& rect) { + return JsonArray{rect.xMin(), rect.yMin(), rect.xMax(), rect.yMax()}; +} + +Color jsonToColor(Json const& v) { + if (v.type() == Json::Type::Array) { + if (v.type() != Json::Type::Array || (v.size() != 3 && v.size() != 4)) + throw JsonException("Json not an array of size 3 or 4 in jsonToColor"); + Color c = Color::rgba(0, 0, 0, 255); + + c.setRed(v.getInt(0)); + c.setGreen(v.getInt(1)); + c.setBlue(v.getInt(2)); + + if (v.size() == 4) + c.setAlpha(v.getInt(3)); + + return c; + } else if (v.type() == Json::Type::String) { + return Color(v.toString()); + } else { + throw JsonException(strf("Json of type %s cannot be converted to color", v.typeName())); + } +} + +Json jsonFromColor(Color const& color) { + JsonArray result; + result.push_back(color.red()); + result.push_back(color.green()); + result.push_back(color.blue()); + if (color.alpha() != 255) { + result.push_back(color.alpha()); + } + return result; +} + +PolyD jsonToPolyD(Json const& v) { + PolyD poly; + + for (Json const& vertex : v.iterateArray()) + poly.add(jsonToVec2D(vertex)); + + return fixInsideOutPoly(poly); +} + +PolyF jsonToPolyF(Json const& v) { + PolyF poly; + + for (Json const& vertex : v.iterateArray()) + poly.add(jsonToVec2F(vertex)); + + return fixInsideOutPoly(poly); +} + +PolyI jsonToPolyI(Json const& v) { + PolyI poly; + + for (Json const& vertex : v.iterateArray()) + poly.add(jsonToVec2I(vertex)); + + return fixInsideOutPoly(poly); +} + +Json jsonFromPolyF(PolyF const& poly) { + JsonArray vertexList; + for (auto const& vertex : poly.vertexes()) + vertexList.append(JsonArray{vertex[0], vertex[1]}); + + return vertexList; +} + +Line2F jsonToLine2F(Json const& v) { + return Line2F(jsonToVec2F(v.get(0)), jsonToVec2F(v.get(1))); +} + +Json jsonFromLine2F(Line2F const& line) { + return JsonArray{jsonFromVec2F(line.min()), jsonFromVec2F(line.max())}; +} + +Mat3F jsonToMat3F(Json const& v) { + return Mat3F(jsonToVec3F(v.get(0)), jsonToVec3F(v.get(1)), jsonToVec3F(v.get(2))); +} + +Json jsonFromMat3F(Mat3F const& v) { + return JsonArray{jsonFromVec3F(v[0]), jsonFromVec3F(v[1]), jsonFromVec3F(v[2])}; +} + +StringList jsonToStringList(Json const& v) { + StringList result; + for (auto const& entry : v.iterateArray()) + result.push_back(entry.toString()); + return result; +} + +Json jsonFromStringList(List<String> const& v) { + JsonArray result; + for (auto e : v) + result.push_back(e); + return result; +} + +List<float> jsonToFloatList(Json const& v) { + List<float> result; + for (auto const& entry : v.iterateArray()) + result.push_back(entry.toFloat()); + return result; +} + +StringSet jsonToStringSet(Json const& v) { + StringSet result; + for (auto const& entry : v.iterateArray()) + result.add(entry.toString()); + return result; +} + +Json jsonFromStringSet(StringSet const& v) { + JsonArray result; + for (auto e : v) + result.push_back(e); + return result; +} + +List<int> jsonToIntList(Json const& v) { + List<int> result; + for (auto const& entry : v.iterateArray()) + result.push_back(entry.toInt()); + return result; +} + +List<Vec2I> jsonToVec2IList(Json const& v) { + List<Vec2I> result; + for (auto const& entry : v.iterateArray()) + result.append(jsonToVec2I(entry)); + return result; +} + +List<Vec2U> jsonToVec2UList(Json const& v) { + List<Vec2U> result; + for (auto const& entry : v.iterateArray()) + result.append(jsonToVec2U(entry)); + return result; +} + +List<Vec2F> jsonToVec2FList(Json const& v) { + List<Vec2F> result; + for (auto const& entry : v.iterateArray()) + result.append(jsonToVec2F(entry)); + return result; +} + +List<Vec4B> jsonToVec4BList(Json const& v) { + List<Vec4B> result; + for (auto const& entry : v.iterateArray()) + result.append(jsonToVec4B(entry)); + return result; +} + +List<Color> jsonToColorList(Json const& v) { + List<Color> result; + for (auto const& entry : v.iterateArray()) + result.append(jsonToColor(entry)); + return result; +} + +Json weightedChoiceFromJson(Json const& source, Json const& default_) { + if (source.isNull()) + return default_; + if (source.type() != Json::Type::Array) + throw StarException("Json of array type expected."); + List<pair<float, Json>> options; + float sum = 0; + size_t idx = 0; + while (idx < source.size()) { + float weight = 1; + Json entry = source.get(idx); + if (entry.type() == Json::Type::Int || entry.type() == Json::Type::Float) { + weight = entry.toDouble(); + idx++; + if (idx >= source.size()) + throw StarException("Weighted companion cube cannot cry."); + sum += weight; + options.append(pair<float, Json>{weight, source.get(idx)}); + } else { + sum += weight; + options.append(pair<float, Json>{weight, entry}); + } + idx++; + } + if (!options.size()) + return default_; + float choice = Random::randf() * sum; + idx = 0; + while (idx < options.size()) { + auto const& entry = options[idx]; + if (entry.first >= choice) + return entry.second; + choice -= entry.first; + idx++; + } + return options[options.size() - 1].second; +} + +Json binnedChoiceFromJson(Json const& bins, float target, Json const& def) { + JsonArray binList = bins.toArray(); + sortByComputedValue(binList, [](Json const& pair) { return -pair.getFloat(0); }); + Json result = def; + for (auto const& pair : binList) { + if (pair.getFloat(0) <= target) { + result = pair.get(1); + break; + } + } + return result; +} + +template <> +WeightedPool<int> jsonToWeightedPool(Json const& source) { + return jsonToWeightedPool<int>(source, [](Json const& v) { return v.toInt(); }); +} + +template <> +WeightedPool<unsigned> jsonToWeightedPool(Json const& source) { + return jsonToWeightedPool<unsigned>(source, [](Json const& v) { return v.toUInt(); }); +} + +template <> +WeightedPool<float> jsonToWeightedPool(Json const& source) { + return jsonToWeightedPool<float>(source, [](Json const& v) { return v.toFloat(); }); +} + +template <> +WeightedPool<double> jsonToWeightedPool(Json const& source) { + return jsonToWeightedPool<double>(source, [](Json const& v) { return v.toDouble(); }); +} + +template <> +WeightedPool<String> jsonToWeightedPool(Json const& source) { + return jsonToWeightedPool<String>(source, [](Json const& v) { return v.toString(); }); +} + +template <> +WeightedPool<JsonArray> jsonToWeightedPool(Json const& source) { + return jsonToWeightedPool<JsonArray>(source, [](Json const& v) { return v.toArray(); }); +} + +template <> +WeightedPool<JsonObject> jsonToWeightedPool(Json const& source) { + return jsonToWeightedPool<JsonObject>(source, [](Json const& v) { return v.toObject(); }); +} + +} diff --git a/source/core/StarJsonExtra.hpp b/source/core/StarJsonExtra.hpp new file mode 100644 index 0000000..3bf5b9c --- /dev/null +++ b/source/core/StarJsonExtra.hpp @@ -0,0 +1,384 @@ +#ifndef STAR_JSON_EXTRA_HPP +#define STAR_JSON_EXTRA_HPP + +#include "StarJson.hpp" +#include "StarPoly.hpp" +#include "StarColor.hpp" +#include "StarSet.hpp" +#include "StarWeightedPool.hpp" + +namespace Star { + +// Extra methods to parse a variety of types out of pure JSON. Throws +// JsonException if json is not of correct type or size. + +size_t jsonToSize(Json const& v); +Json jsonFromSize(size_t s); + +// Must be array of appropriate size. + +Vec2D jsonToVec2D(Json const& v); +Vec2F jsonToVec2F(Json const& v); +Json jsonFromVec2F(Vec2F const& v); +Vec2I jsonToVec2I(Json const& v); +Json jsonFromVec2I(Vec2I const& v); +Vec2U jsonToVec2U(Json const& v); +Json jsonFromVec2U(Vec2U const& v); +Vec2B jsonToVec2B(Json const& v); +Json jsonFromVec2B(Vec2B const& v); + +Vec3D jsonToVec3D(Json const& v); +Vec3F jsonToVec3F(Json const& v); +Json jsonFromVec3F(Vec3F const& v); +Vec3I jsonToVec3I(Json const& v); +Json jsonFromVec3I(Vec3I const& v); +Vec3B jsonToVec3B(Json const& v); + +Vec4B jsonToVec4B(Json const& v); +Vec4I jsonToVec4I(Json const& v); +Vec4F jsonToVec4F(Json const& v); + +// Must be array of size 4 or 2 arrays of size 2 in an array. +RectD jsonToRectD(Json const& v); +Json jsonFromRectD(RectD const& rect); +RectF jsonToRectF(Json const& v); +Json jsonFromRectF(RectF const& rect); +RectI jsonToRectI(Json const& v); +Json jsonFromRectI(RectI const& rect); +RectU jsonToRectU(Json const& v); +Json jsonFromRectU(RectU const& rect); + +// Can be a string, array of size 3 or 4 of doubles or ints. If double, range +// is 0.0 to 1.0, if int range is 0-255 +Color jsonToColor(Json const& v); +Json jsonFromColor(Color const& color); + +// HACK: Fix clockwise specified polygons in coming from JSON +template <typename Float> +Polygon<Float> fixInsideOutPoly(Polygon<Float> p); + +// Array of size 2 arrays +PolyD jsonToPolyD(Json const& v); +PolyF jsonToPolyF(Json const& v); +PolyI jsonToPolyI(Json const& v); +Json jsonFromPolyF(PolyF const& poly); + +// Expects a size 2 array of size 2 arrays +Line2F jsonToLine2F(Json const& v); +Json jsonFromLine2F(Line2F const& line); + +Mat3F jsonToMat3F(Json const& v); +Json jsonFromMat3F(Mat3F const& v); + +StringList jsonToStringList(Json const& v); +Json jsonFromStringList(List<String> const& v); +StringSet jsonToStringSet(Json const& v); +Json jsonFromStringSet(StringSet const& v); +List<float> jsonToFloatList(Json const& v); +List<int> jsonToIntList(Json const& v); +List<Vec2I> jsonToVec2IList(Json const& v); +List<Vec2U> jsonToVec2UList(Json const& v); +List<Vec2F> jsonToVec2FList(Json const& v); +List<Vec4B> jsonToVec4BList(Json const& v); +List<Color> jsonToColorList(Json const& v); + +Json weightedChoiceFromJson(Json const& source, Json const& default_); + +// Assumes that the bins parameter is an array of pairs (arrays), where the +// first element is a minimum value and the second element is the actual +// important value. Finds the pair with the highest value that is less than or +// equal to the given target, and returns the second element. +Json binnedChoiceFromJson(Json const& bins, float target, Json const& def = Json()); + +template <typename T> +WeightedPool<T> jsonToWeightedPool(Json const& source); +template <typename T, typename Converter> +WeightedPool<T> jsonToWeightedPool(Json const& source, Converter&& converter); + +template <typename T> +Json jsonFromWeightedPool(WeightedPool<T> const& pool); +template <typename T, typename Converter> +Json jsonFromWeightedPool(WeightedPool<T> const& pool, Converter&& converter); + +template <size_t Size> +Array<unsigned, Size> jsonToArrayU(Json const& v) { + if (v.size() != Size) + throw JsonException(strf("Json array not of size %d in jsonToArrayU", Size).c_str()); + + Array<unsigned, Size> res; + for (size_t i = 0; i < Size; i++) { + res[i] = v.getUInt(i); + } + + return res; +} + +template <size_t Size> +Array<size_t, Size> jsonToArrayS(Json const& v) { + if (v.size() != Size) + throw JsonException(strf("Json array not of size %d in jsonToArrayS", Size).c_str()); + + Array<size_t, Size> res; + for (size_t i = 0; i < Size; i++) { + res[i] = v.getUInt(i); + } + + return res; +} + +template <size_t Size> +Array<int, Size> jsonToArrayI(Json const& v) { + if (v.size() != Size) + throw JsonException(strf("Json array not of size %d in jsonToArrayI", Size).c_str()); + + Array<int, Size> res; + for (size_t i = 0; i < Size; i++) { + res[i] = v.getInt(i); + } + + return res; +} + +template <size_t Size> +Array<float, Size> jsonToArrayF(Json const& v) { + if (v.size() != Size) + throw JsonException(strf("Json array not of size %d in jsonToArrayF", Size).c_str()); + + Array<float, Size> res; + for (size_t i = 0; i < Size; i++) { + res[i] = v.getFloat(i); + } + + return res; +} + +template <size_t Size> +Array<double, Size> jsonToArrayD(Json const& v) { + if (v.size() != Size) + throw JsonException(strf("Json array not of size %d in jsonToArrayD", Size).c_str()); + + Array<double, Size> res; + for (size_t i = 0; i < Size; i++) { + res[i] = v.getDouble(i); + } + + return res; +} + +template <size_t Size> +Array<String, Size> jsonToStringArray(Json const& v) { + if (v.size() != Size) + throw JsonException(strf("Json array not of size %d in jsonToStringArray", Size).c_str()); + + Array<String, Size> res; + for (size_t i = 0; i < Size; i++) { + res[i] = v.getString(i); + } + + return res; +} + +template <typename Value> +List<Value> jsonToList(Json const& v) { + return jsonToList<Value>(v, construct<Value>()); +} + +template <typename Value, typename Converter> +List<Value> jsonToList(Json const& v, Converter&& valueConvert) { + if (v.type() != Json::Type::Array) + throw JsonException("Json type is not a array in jsonToList"); + + List<Value> res; + for (auto const& entry : v.iterateArray()) + res.push_back(valueConvert(entry)); + + return res; +} + +template <typename Value> +Json jsonFromList(List<Value> const& list) { + return jsonFromList<Value>(list, construct<Json>()); +} + +template <typename Value, typename Converter> +Json jsonFromList(List<Value> const& list, Converter&& valueConvert) { + JsonArray res; + for (auto const& entry : list) + res.push_back(valueConvert(entry)); + + return res; +} + +template <typename Value> +Set<Value> jsonToSet(Json const& v) { + return jsonToSet<Value>(v, construct<Value>()); +} + +template <typename Value, typename Converter> +Set<Value> jsonToSet(Json const& v, Converter&& valueConvert) { + if (v.type() != Json::Type::Array) + throw JsonException("Json type is not an array in jsonToSet"); + + Set<Value> res; + for (auto const& entry : v.iterateArray()) + res.add(valueConvert(entry)); + + return res; +} + +template <typename Value> +Json jsonFromSet(Set<Value> const& Set) { + return jsonFromSet<Value>(Set, construct<Json>()); +} + +template <typename Value, typename Converter> +Json jsonFromSet(Set<Value> const& Set, Converter&& valueConvert) { + JsonArray res; + for (auto entry : Set) + res.push_back(valueConvert(entry)); + + return res; +} + +template <typename MapType, typename KeyConverter, typename ValueConverter> +MapType jsonToMapKV(Json const& v, KeyConverter&& keyConvert, ValueConverter&& valueConvert) { + if (v.type() != Json::Type::Object) + throw JsonException("Json type is not an object in jsonToMap"); + + MapType res; + for (auto const& pair : v.iterateObject()) + res.add(keyConvert(pair.first), valueConvert(pair.second)); + + return res; +} + +template <typename MapType, typename KeyConverter> +MapType jsonToMapK(Json const& v, KeyConverter&& keyConvert) { + return jsonToMapKV<MapType>(v, forward<KeyConverter>(keyConvert), construct<typename MapType::mapped_type>()); +} + +template <typename MapType, typename ValueConverter> +MapType jsonToMapV(Json const& v, ValueConverter&& valueConvert) { + return jsonToMapKV<MapType>(v, construct<typename MapType::key_type>(), forward<ValueConverter>(valueConvert)); +} + +template <typename MapType> +MapType jsonToMap(Json const& v) { + return jsonToMapKV<MapType>(v, construct<typename MapType::key_type>(), construct<typename MapType::mapped_type>()); +} + +template <typename MapType, typename KeyConverter, typename ValueConverter> +Json jsonFromMapKV(MapType const& map, KeyConverter&& keyConvert, ValueConverter&& valueConvert) { + JsonObject res; + for (auto pair : map) + res[keyConvert(pair.first)] = valueConvert(pair.second); + + return res; +} + +template <typename MapType, typename KeyConverter> +Json jsonFromMapK(MapType const& map, KeyConverter&& keyConvert) { + return jsonFromMapKV<MapType>(map, forward<KeyConverter>(keyConvert), construct<Json>()); +} + +template <typename MapType, typename ValueConverter> +Json jsonFromMapV(MapType const& map, ValueConverter&& valueConvert) { + return jsonFromMapKV<MapType>(map, construct<String>(), forward<ValueConverter>(valueConvert)); +} + +template <typename MapType> +Json jsonFromMap(MapType const& map) { + return jsonFromMapKV<MapType>(map, construct<String>(), construct<Json>()); +} + +template <typename T, typename Converter> +Json jsonFromMaybe(Maybe<T> const& m, Converter&& converter) { + return m.apply(converter).value(); +} + +template <typename T> +Json jsonFromMaybe(Maybe<T> const& m) { + return jsonFromMaybe(m, construct<Json>()); +} + +template <typename T, typename Converter> +Maybe<T> jsonToMaybe(Json v, Converter&& converter) { + if (v.isNull()) + return {}; + return converter(v); +} + +template <typename T> +Maybe<T> jsonToMaybe(Json const& v) { + return jsonToMaybe<T>(v, construct<T>()); +} + +template <typename T> +WeightedPool<T> jsonToWeightedPool(Json const& source) { + return jsonToWeightedPool<T>(source, construct<T>()); +} + +template <typename T, typename Converter> +WeightedPool<T> jsonToWeightedPool(Json const& source, Converter&& converter) { + WeightedPool<T> res; + if (source.isNull()) + return res; + for (auto entry : source.iterateArray()) { + if (entry.isType(Json::Type::Array)) + res.add(entry.get(0).toDouble(), converter(entry.get(1))); + else + res.add(entry.getDouble("weight"), converter(entry.get("item"))); + } + + return res; +} + +template <typename T> +Json jsonFromWeightedPool(WeightedPool<T> const& pool) { + return jsonFromWeightedPool<T>(pool, construct<Json>()); +} + +template <typename T, typename Converter> +Json jsonFromWeightedPool(WeightedPool<T> const& pool, Converter&& converter) { + JsonArray res; + for (auto const& pair : pool.items()) { + res.append(JsonObject{ + {"weight", pair.first}, {"item", converter(pair.second)}, + }); + } + return res; +} + +template <> +WeightedPool<int> jsonToWeightedPool(Json const& source); + +template <> +WeightedPool<unsigned> jsonToWeightedPool(Json const& source); + +template <> +WeightedPool<float> jsonToWeightedPool(Json const& source); + +template <> +WeightedPool<double> jsonToWeightedPool(Json const& source); + +template <> +WeightedPool<String> jsonToWeightedPool(Json const& source); + +template <> +WeightedPool<JsonArray> jsonToWeightedPool(Json const& source); + +template <> +WeightedPool<JsonObject> jsonToWeightedPool(Json const& source); + +template <typename Float> +Polygon<Float> fixInsideOutPoly(Polygon<Float> p) { + if (p.sides() > 2) { + if ((p.side(1).diff() ^ p.side(0).diff()) > 0) + reverse(p.vertexes()); + } + return p; +} + +} + +#endif diff --git a/source/core/StarJsonParser.hpp b/source/core/StarJsonParser.hpp new file mode 100644 index 0000000..87a4bd9 --- /dev/null +++ b/source/core/StarJsonParser.hpp @@ -0,0 +1,733 @@ +#ifndef STAR_JSON_PARSER_HPP +#define STAR_JSON_PARSER_HPP + +#include <vector> + +#include "StarUnicode.hpp" + +namespace Star { + +struct JsonStream { + virtual ~JsonStream() {} + + virtual void beginObject() = 0; + virtual void objectKey(char32_t const*, size_t) = 0; + virtual void endObject() = 0; + + virtual void beginArray() = 0; + virtual void endArray() = 0; + + virtual void putString(char32_t const*, size_t) = 0; + virtual void putDouble(char32_t const*, size_t) = 0; + virtual void putInteger(char32_t const*, size_t) = 0; + virtual void putBoolean(bool) = 0; + virtual void putNull() = 0; + + virtual void putWhitespace(char32_t const*, size_t) = 0; + virtual void putColon() = 0; + virtual void putComma() = 0; +}; + +// Will parse JSON and output to a given JsonStream. Parses an *extension* to +// the JSON format that includes comments. +template <typename InputIterator> +class JsonParser { +public: + JsonParser(JsonStream& stream) + : m_line(0), m_column(0), m_stream(stream) {} + virtual ~JsonParser() {} + + // Does not throw. On error, returned iterator will not be equal to end, and + // error() will be non-null. Set fragment to true to parse any JSON type + // rather than just object or array. + InputIterator parse(InputIterator begin, InputIterator end, bool fragment = false) { + init(begin, end); + + try { + white(); + if (fragment) + value(); + else + top(); + white(); + } catch (ParsingException const&) { + } + + return m_current; + } + + // Human readable parsing error, does not include line or column info. + char const* error() const { + if (m_error.empty()) + return nullptr; + else + return m_error.c_str(); + } + + size_t line() const { + return m_line + 1; + } + + size_t column() const { + return m_column + 1; + } + +private: + typedef std::basic_string<char32_t> CharArray; + + // Thrown internally to abort parsing. + class ParsingException {}; + + void top() { + switch (m_char) { + case '{': + object(); + break; + case '[': + array(); + break; + default: + error("expected JSON object or array at top level"); + return; + } + } + + void value() { + switch (m_char) { + case '{': + object(); + break; + case '[': + array(); + break; + case '"': + string(); + break; + case '-': + number(); + break; + case 0: + error("unexpected end of stream parsing value"); + return; + default: + m_char >= '0' && m_char <= '9' ? number() : word(); + break; + } + } + + void object() { + if (m_char != '{') + error("bad object, should be '{'"); + + next(); + m_stream.beginObject(); + + white(); + if (m_char == '}') { + next(); + m_stream.endObject(); + return; + } + + while (true) { + CharArray s = parseString(); + m_stream.objectKey(s.c_str(), s.length()); + + white(); + if (m_char != ':') + error("bad object, should be ':'"); + next(); + m_stream.putColon(); + white(); + + value(); + + white(); + if (m_char == '}') { + next(); + m_stream.endObject(); + return; + } else if (m_char == ',') { + next(); + m_stream.putComma(); + white(); + } else if (m_char == 0) { + error("unexpected end of stream parsing object."); + } else { + error("bad object, should be '}' or ','"); + } + } + } + + void array() { + if (m_char == '[') { + next(); + m_stream.beginArray(); + white(); + if (m_char == ']') { + next(); + m_stream.endArray(); + } else { + while (true) { + value(); + white(); + if (m_char == ']') { + next(); + m_stream.endArray(); + break; + } else if (m_char == ',') { + next(); + m_stream.putComma(); + white(); + } else if (m_char == 0) { + error("unexpected end of stream parsing array."); + } else { + error("bad array, should be ',' or ']'"); + } + } + } + } else { + error("bad array"); + } + } + + void string() { + CharArray s = parseString(); + m_stream.putString(s.c_str(), s.length()); + } + + void number() { + std::basic_string<char32_t> buffer; + bool hasDot = false; + + if (m_char == '-') { + buffer += '-'; + next(); + } + + if (m_char == '0') { + buffer += '0'; + next(); + } else if (m_char > '0' && m_char <= '9') { + while (m_char >= '0' && m_char <= '9') { + buffer += m_char; + next(); + } + } else { + error("bad number, must start with digit"); + } + + if (m_char == '.') { + hasDot = true; + buffer += '.'; + next(); + while (m_char >= '0' && m_char <= '9') { + buffer += m_char; + next(); + } + } + + if (m_char == 'e' || m_char == 'E') { + buffer += m_char; + next(); + if (m_char == '-' || m_char == '+') { + buffer += m_char; + next(); + } + while (m_char >= '0' && m_char <= '9') { + buffer += m_char; + next(); + } + } + + if (hasDot) { + try { + m_stream.putDouble(buffer.c_str(), buffer.length()); + } catch (std::exception const& e) { + error(std::string("Bad double: ") + e.what()); + } + } else { + try { + m_stream.putInteger(buffer.c_str(), buffer.length()); + } catch (std::exception const& e) { + error(std::string("Bad integer: ") + e.what()); + } + } + } + + // true, false, or null + void word() { + switch (m_char) { + case 't': + next(); + check('r'); + check('u'); + check('e'); + m_stream.putBoolean(true); + break; + case 'f': + next(); + check('a'); + check('l'); + check('s'); + check('e'); + m_stream.putBoolean(false); + break; + case 'n': + next(); + check('u'); + check('l'); + check('l'); + m_stream.putNull(); + break; + default: + error("unexpected character parsing word"); + return; + } + } + + CharArray parseString() { + if (m_char != '"') + error("bad string, should be '\"'"); + next(); + + CharArray str; + + while (true) { + if (m_char == '\\') { + next(); + if (m_char == 'u') { + std::string hexString; + next(); + for (int i = 0; i < 4; ++i) { + hexString.push_back(m_char); + next(); + } + char32_t codepoint = hexStringToUtf32(hexString); + if (isUtf16LeadSurrogate(codepoint)) { + check('\\'); + check('u'); + hexString.clear(); + for (int i = 0; i < 4; ++i) { + hexString.push_back(m_char); + next(); + } + codepoint = hexStringToUtf32(hexString, codepoint); + } + str += codepoint; + } else { + switch (m_char) { + case '"': + str += '"'; + break; + case '\\': + str += '\\'; + break; + case '/': + str += '/'; + break; + case 'b': + str += '\b'; + break; + case 'f': + str += '\f'; + break; + case 'n': + str += '\n'; + break; + case 'r': + str += '\r'; + break; + case 't': + str += '\t'; + break; + default: + error("bad string escape character"); + break; + } + next(); + } + } else if (m_char == '\"') { + next(); + return str; + } else if (m_char == 0) { + error("unexpected end of stream reading string!"); + } else { + str += m_char; + next(); + } + } + error("parser bug"); + return {}; + } + + // Checks current char then moves on to the next one + void check(char32_t c) { + if (m_char == 0) + error("unexpected end of stream parsing word"); + if (m_char != c) + error("unexpected character in word"); + next(); + } + + void init(InputIterator begin, InputIterator end) { + m_current = begin; + m_end = end; + m_line = 0; + m_column = 0; + + if (m_current != m_end) + m_char = *m_current; + else + m_char = 0; + } + + // Consumes next character. + void next() { + if (m_current == m_end) + return; + + if (m_char == '\n') { + ++m_line; + m_column = 0; + } else { + ++m_column; + } + ++m_current; + + if (m_current != m_end) + m_char = *m_current; + else + m_char = 0; + } + + // Will skip whitespace and comments between tokens. + void white() { + CharArray buffer; + while (m_current != m_end) { + if (m_char == '/') { + // Always consume '/' found in whitespace, because that is never valid + // JSON (other than comments) + buffer += m_char; + next(); + if (m_current != m_end && m_char == '/') { + // eat "/" + buffer += m_char; + next(); + + // Read '//' style comments up until eol/eof. + while (m_current != m_end && m_char != '\n') { + buffer += m_char; + next(); + } + } else if (m_current != m_end && m_char == '*') { + // eat "*" + buffer += m_char; + next(); + + // Read '/*' style comments up until '*/'. + while (m_current != m_end) { + if (m_char == '*') { + buffer += m_char; + next(); + if (m_char == '/') { + buffer += m_char; + next(); + break; + } + } else { + buffer += m_char; + next(); + if (m_current == m_end) + error("/* comment has no matching */"); + } + } + } else { + // The only allowed characters following / in whitespace are / and * + error("/ character in whitespace is not follwed by '/' or '*', invalid comment"); + return; + } + } else if (isSpace(m_char)) { + buffer += m_char; + next(); + } else { + if (buffer.size() != 0) + m_stream.putWhitespace(buffer.c_str(), buffer.length()); + return; + } + } + if (buffer.size() != 0) + m_stream.putWhitespace(buffer.c_str(), buffer.length()); + } + + void error(std::string msg) { + m_error = move(msg); + throw ParsingException(); + } + + bool isSpace(char32_t c) { + // Only whitespace allowed by JSON + return c == 0x20 || // space + c == 0x09 || // horizontal tab + c == 0x0a || // newline + c == 0x0d || // carriage return + c == 0xfeff; // BOM or ZWNBSP + } + + char32_t m_char; + InputIterator m_current; + InputIterator m_end; + size_t m_line; + size_t m_column; + std::string m_error; + JsonStream& m_stream; +}; + +// Write JSON through JsonStream interface. +template <typename OutputIterator> +class JsonWriter : public JsonStream { +public: + JsonWriter(OutputIterator out, unsigned pretty = 0) + : m_out(out), m_pretty(pretty) {} + + void beginObject() { + startValue(); + pushState(Object); + write('{'); + } + + void objectKey(char32_t const* s, size_t len) { + if (currentState() == ObjectElement) { + if (m_pretty) + write('\n'); + indent(); + } else { + pushState(ObjectElement); + if (m_pretty) + write('\n'); + indent(); + } + + write('"'); + char32_t c = *s; + while (c && len) { + write(c); + c = *++s; + --len; + } + write('"'); + if (m_pretty) + write(' '); + } + + void endObject() { + popState(Object); + + if (m_pretty) + write('\n'); + indent(); + write('}'); + } + + void beginArray() { + startValue(); + pushState(Array); + write('['); + } + + void endArray() { + popState(Array); + write(']'); + } + + void putString(char32_t const* s, size_t len) { + startValue(); + + write('"'); + char32_t c = *s; + while (c && (len > 0)) { + if (!isPrintable(c)) { + switch (c) { + case '"': + write('\\'); + write('"'); + break; + case '\\': + write('\\'); + write('\\'); + break; + case '\b': + write('\\'); + write('b'); + break; + case '\f': + write('\\'); + write('f'); + break; + case '\n': + write('\\'); + write('n'); + break; + case '\r': + write('\\'); + write('r'); + break; + case '\t': + write('\\'); + write('t'); + break; + default: + auto hex = hexStringFromUtf32(c); + if (hex.size() == 4) { + write('\\'); + write('u'); + for (auto c : hex) { + write(c); + } + } else if (hex.size() == 8) { + write('\\'); + write('u'); + for (auto c : hex.substr(0, 4)) { + write(c); + } + write('\\'); + write('u'); + for (auto c : hex.substr(4)) { + write(c); + } + } else { + throw UnicodeException("Internal Error: Received invalid unicode hex from hexStringFromUtf32."); + } + break; + } + } else { + write(c); + } + c = *++s; + --len; + } + write('"'); + } + + void putDouble(char32_t const* s, size_t len) { + startValue(); + for (size_t i = 0; i < len; ++i) + write(s[i]); + } + + void putInteger(char32_t const* s, size_t len) { + startValue(); + for (size_t i = 0; i < len; ++i) + write(s[i]); + } + + void putBoolean(bool b) { + startValue(); + if (b) { + write('t'); + write('r'); + write('u'); + write('e'); + } else { + write('f'); + write('a'); + write('l'); + write('s'); + write('e'); + } + } + + void putNull() { + startValue(); + write('n'); + write('u'); + write('l'); + write('l'); + } + + void putWhitespace(char32_t const* s, size_t len) { + // If m_pretty is true, extra spurious whitespace will be inserted. + for (size_t i = 0; i < len; ++i) + write(s[i]); + } + + void putColon() { + write(':'); + if (m_pretty) + write(' '); + } + + void putComma() { + write(','); + } + +private: + enum State { + Top, + Object, + ObjectElement, + Array, + ArrayElement + }; + + // Handles separating array elements if currently adding to an array + void startValue() { + if (currentState() == ArrayElement) { + if (m_pretty) + write(' '); + } else if (currentState() == Array) { + pushState(ArrayElement); + } + } + + void indent() { + for (unsigned i = 0; i < m_state.size() / 2; ++i) { + for (unsigned j = 0; j < m_pretty; ++j) { + write(' '); + } + } + } + + // Push state onto stack. + void pushState(State state) { + m_state.push_back(state); + } + + // Pop state stack down to given state. + void popState(State state) { + while (true) { + if (m_state.empty()) + return; + + State last = currentState(); + m_state.pop_back(); + if (last == state) + return; + } + } + + State currentState() { + if (m_state.empty()) + return Top; + else + return *prev(m_state.end()); + } + + void write(char32_t c) { + *m_out = c; + ++m_out; + } + + // Only chars that are unescaped according to JSON spec. + bool isPrintable(char32_t c) { + return (c >= 0x20 && c <= 0x21) || (c >= 0x23 && c <= 0x5b) || (c >= 0x5d && c <= 0x10ffff); + } + + OutputIterator m_out; + unsigned m_pretty; + std::vector<State> m_state; +}; + +} + +#endif diff --git a/source/core/StarJsonPatch.cpp b/source/core/StarJsonPatch.cpp new file mode 100644 index 0000000..e1ab6de --- /dev/null +++ b/source/core/StarJsonPatch.cpp @@ -0,0 +1,97 @@ +#include "StarJsonPatch.hpp" +#include "StarJsonPath.hpp" +#include "StarLexicalCast.hpp" + +namespace Star { + +Json jsonPatch(Json const& base, JsonArray const& patch) { + auto res = base; + try { + for (auto const& operation : patch) { + res = JsonPatching::applyOperation(res, operation); + } + return res; + } catch (JsonException const& e) { + throw JsonPatchException(strf("Could not apply patch to base. %s", e.what())); + } +} + +namespace JsonPatching { + + static const StringMap<std::function<Json(Json, Json)>> functionMap = StringMap<std::function<Json(Json, Json)>>{ + {"test", std::bind(applyTestOperation, _1, _2)}, + {"remove", std::bind(applyRemoveOperation, _1, _2)}, + {"add", std::bind(applyAddOperation, _1, _2)}, + {"replace", std::bind(applyReplaceOperation, _1, _2)}, + {"move", std::bind(applyMoveOperation, _1, _2)}, + {"copy", std::bind(applyCopyOperation, _1, _2)}, + }; + + Json applyOperation(Json const& base, Json const& op) { + try { + auto operation = op.getString("op"); + return JsonPatching::functionMap.get(operation)(base, op); + } catch (JsonException const& e) { + throw JsonPatchException(strf("Could not apply operation to base. %s", e.what())); + } catch (MapException const&) { + throw JsonPatchException(strf("Invalid operation: %s", op.getString("op"))); + } + } + + Json applyTestOperation(Json const& base, Json const& op) { + auto path = op.getString("path"); + auto value = op.opt("value"); + auto inverseTest = op.getBool("inverse", false); + + auto pointer = JsonPath::Pointer(path); + + try { + auto testValue = pointer.get(base); + if (!value) { + if (inverseTest) + throw JsonPatchTestFail(strf("Test operation failure, expected %s to be missing.", op.getString("path"))); + return base; + } + + if ((value && (testValue == *value)) ^ inverseTest) { + return base; + } + + throw JsonPatchTestFail(strf("Test operation failure, expected %s found %s.", value, testValue)); + } catch (JsonPath::TraversalException& e) { + if (inverseTest) + return base; + throw JsonPatchTestFail(strf("Test operation failure: %s", e.what())); + } + } + + Json applyRemoveOperation(Json const& base, Json const& op) { + return JsonPath::Pointer(op.getString("path")).remove(base); + } + + Json applyAddOperation(Json const& base, Json const& op) { + return JsonPath::Pointer(op.getString("path")).add(base, op.get("value")); + } + + Json applyReplaceOperation(Json const& base, Json const& op) { + auto pointer = JsonPath::Pointer(op.getString("path")); + return pointer.add(pointer.remove(base), op.get("value")); + } + + Json applyMoveOperation(Json const& base, Json const& op) { + auto fromPointer = JsonPath::Pointer(op.getString("from")); + auto toPointer = JsonPath::Pointer(op.getString("path")); + + Json value = fromPointer.get(base); + return toPointer.add(fromPointer.remove(base), value); + } + + Json applyCopyOperation(Json const& base, Json const& op) { + auto fromPointer = JsonPath::Pointer(op.getString("from")); + auto toPointer = JsonPath::Pointer(op.getString("path")); + + return toPointer.add(base, fromPointer.get(base)); + } +} + +} diff --git a/source/core/StarJsonPatch.hpp b/source/core/StarJsonPatch.hpp new file mode 100644 index 0000000..1cf689c --- /dev/null +++ b/source/core/StarJsonPatch.hpp @@ -0,0 +1,41 @@ +#ifndef STAR_JSON_PATCH_HPP +#define STAR_JSON_PATCH_HPP + +#include "StarJson.hpp" + +namespace Star { + +STAR_EXCEPTION(JsonPatchException, JsonException); +STAR_EXCEPTION(JsonPatchTestFail, StarException); + +// Applies the given RFC6902 compliant patch to the base and returns the result +// Throws JsonPatchException on patch failure. +Json jsonPatch(Json const& base, JsonArray const& patch); + +namespace JsonPatching { + // Applies the given single operation + Json applyOperation(Json const& base, Json const& op); + + // Tests for "value" at "path" + // Returns base or throws JsonPatchException + Json applyTestOperation(Json const& base, Json const& op); + + // Removes the value at "path" + Json applyRemoveOperation(Json const& base, Json const& op); + + // Adds "value" at "path" + Json applyAddOperation(Json const& base, Json const& op); + + // Replaces "path" with "value" + Json applyReplaceOperation(Json const& base, Json const& op); + + // Moves "from" to "path" + Json applyMoveOperation(Json const& base, Json const& op); + + // Copies "from" to "path" + Json applyCopyOperation(Json const& base, Json const& op); +} + +} + +#endif diff --git a/source/core/StarJsonPath.cpp b/source/core/StarJsonPath.cpp new file mode 100644 index 0000000..9eb5fae --- /dev/null +++ b/source/core/StarJsonPath.cpp @@ -0,0 +1,76 @@ +#include "StarJsonPath.hpp" + +namespace Star { + +namespace JsonPath { + + TypeHint parsePointer(String& buffer, String const& path, String::const_iterator& iterator, String::const_iterator end) { + buffer.clear(); + + if (*iterator != '/') + throw ParsingException::format("Missing leading '/' in Json pointer \"%s\"", path); + iterator++; + + while (iterator != end && *iterator != '/') { + if (*iterator == '~') { + ++iterator; + if (iterator == end) + throw ParsingException::format("Incomplete escape sequence in Json pointer \"%s\"", path); + + if (*iterator == '0') + buffer.append('~'); + else if (*iterator == '1') + buffer.append('/'); + else + throw ParsingException::format("Invalid escape sequence in Json pointer \"%s\"", path); + ++iterator; + } else + buffer.append(*iterator++); + } + + Maybe<size_t> index = maybeLexicalCast<size_t>(buffer); + if (index.isValid() || (buffer == "-" && iterator == end)) + return TypeHint::Array; + return TypeHint::Object; + } + + TypeHint parseQueryPath(String& buffer, String const& path, String::const_iterator& iterator, String::const_iterator end) { + buffer.clear(); + + if (*iterator == '.') { + throw ParsingException::format("Entry starts with '.' in query path \"%s\"", path); + + } else if (*iterator == '[') { + // Parse array number and ']' + // Consume initial '[' + ++iterator; + + while (iterator != end && *iterator >= '0' && *iterator <= '9') + buffer.append(*iterator++); + + if (iterator == end || *iterator != ']') + throw ParsingException::format("Array has no trailing ']' or has invalid character in query path \"%s\"", path); + + // Consume trailing ']' + ++iterator; + + // Consume trailing '.' + if (iterator != end && *iterator == '.') + ++iterator; + + return TypeHint::Array; + + } else { + // Parse path up to next '.' or '[' + while (iterator != end && *iterator != '.' && *iterator != '[') + buffer.append(*iterator++); + + // Consume single trailing '.' if it exists + if (iterator != end && *iterator == '.') + ++iterator; + return TypeHint::Object; + } + } +} + +} diff --git a/source/core/StarJsonPath.hpp b/source/core/StarJsonPath.hpp new file mode 100644 index 0000000..0f0efa9 --- /dev/null +++ b/source/core/StarJsonPath.hpp @@ -0,0 +1,332 @@ +#ifndef STAR_JSON_PATH_HPP +#define STAR_JSON_PATH_HPP + +#include "StarLexicalCast.hpp" +#include "StarJson.hpp" + +namespace Star { + +namespace JsonPath { + enum class TypeHint { + Array, + Object + }; + + typedef function<TypeHint(String&, String const&, String::const_iterator&, String::const_iterator)> PathParser; + + STAR_EXCEPTION(ParsingException, JsonException); + STAR_EXCEPTION(TraversalException, JsonException); + + // Parses RFC 6901 JSON Pointers, e.g. /foo/bar/4/baz + TypeHint parsePointer(String& outputBuffer, String const& path, String::const_iterator& iterator, String::const_iterator end); + + // Parses JavaScript-like paths, e.g. foo.bar[4].baz + TypeHint parseQueryPath(String& outputBuffer, String const& path, String::const_iterator& iterator, String::const_iterator end); + + // Retrieves the portion of the Json document referred to by the given path. + template <typename Jsonlike> + Jsonlike pathGet(Jsonlike base, PathParser parser, String const& path); + + // Find a given portion of the JSON document, if it exists. Instead of + // throwing a TraversalException if a portion of the path is invalid, simply + // returns nothing. + template <typename Jsonlike> + Maybe<Jsonlike> pathFind(Jsonlike base, PathParser parser, String const& path); + + template <typename Jsonlike> + using JsonOp = function<Jsonlike(Jsonlike const&, Maybe<String> const&)>; + + // Applies a function to the portion of the Json document referred to by the + // given path, returning the resulting new document. If the end of the path + // doesn't exist, the JsonOp is called with None, and its result will be + // inserted into the document. If the path already existed and the JsonOp + // returns None, it is erased. This is not as well-optimized as pathGet, but + // also not on the critical path for anything. + template <typename Jsonlike> + Jsonlike pathApply(Jsonlike const& base, PathParser parser, String const& path, JsonOp<Jsonlike> op); + + // Sets a value on a Json document at the location referred to by path, + // returning the resulting new document. + template <typename Jsonlike> + Jsonlike pathSet(Jsonlike const& base, PathParser parser, String const& path, Jsonlike const& value); + + // Erases the location referred to by the path from the document + template <typename Jsonlike> + Jsonlike pathRemove(Jsonlike const& base, PathParser parser, String const& path); + + // Performs RFC6902 (JSON Patching) add operation. Inserts into arrays, or + // appends if the last path segment is "-". On objects, does the same as + // pathSet. + template <typename Jsonlike> + Jsonlike pathAdd(Jsonlike const& base, PathParser parser, String const& path, Jsonlike const& value); + + template <typename Jsonlike> + using EmptyPathOp = function<Jsonlike(Jsonlike const&)>; + template <typename Jsonlike> + using ObjectOp = function<Jsonlike(Jsonlike const&, String const&)>; + template <typename Jsonlike> + using ArrayOp = function<Jsonlike(Jsonlike const&, Maybe<size_t>)>; + + template <typename Jsonlike> + JsonOp<Jsonlike> genericObjectArrayOp(String path, EmptyPathOp<Jsonlike> emptyPathOp, ObjectOp<Jsonlike> objectOp, ArrayOp<Jsonlike> arrayOp); + + STAR_CLASS(Path); + STAR_CLASS(Pointer); + STAR_CLASS(QueryPath); + + class Path { + public: + Path(PathParser parser, String const& path) : m_parser(parser), m_path(path) {} + + template <typename Jsonlike> + Jsonlike get(Jsonlike const& base) { + return pathGet(base, m_parser, m_path); + } + + template <typename Jsonlike> + Jsonlike apply(Jsonlike const& base, JsonOp<Jsonlike> op) { + return pathApply(base, m_parser, m_path, op); + } + + template <typename Jsonlike> + Jsonlike apply(Jsonlike const& base, + EmptyPathOp<Jsonlike> emptyPathOp, + ObjectOp<Jsonlike> objectOp, + ArrayOp<Jsonlike> arrayOp) { + JsonOp<Jsonlike> combinedOp = genericObjectArrayOp(m_path, emptyPathOp, objectOp, arrayOp); + return pathApply(base, m_parser, m_path, combinedOp); + } + + template <typename Jsonlike> + Jsonlike set(Jsonlike const& base, Jsonlike const& value) { + return pathSet(base, m_parser, m_path, value); + } + + template <typename Jsonlike> + Jsonlike remove(Jsonlike const& base) { + return pathRemove(base, m_parser, m_path); + } + + template <typename Jsonlike> + Jsonlike add(Jsonlike const& base, Jsonlike const& value) { + return pathAdd(base, m_parser, m_path, value); + } + + String const& path() const { + return m_path; + } + + private: + PathParser m_parser; + String m_path; + }; + + class Pointer : public Path { + public: + Pointer(String const& path) : Path(parsePointer, path) {} + }; + + class QueryPath : public Path { + public: + QueryPath(String const& path) : Path(parseQueryPath, path) {} + }; + + template <typename Jsonlike> + Jsonlike pathGet(Jsonlike value, PathParser parser, String const& path) { + String buffer; + buffer.reserve(path.size()); + + auto pos = path.begin(); + + while (pos != path.end()) { + parser(buffer, path, pos, path.end()); + + if (value.type() == Json::Type::Array) { + if (buffer == "-") + throw TraversalException::format("Tried to get key '%s' in non-object type in pathGet(\"%s\")", buffer, path); + Maybe<size_t> i = maybeLexicalCast<size_t>(buffer); + if (!i) + throw TraversalException::format("Cannot parse '%s' as index in pathGet(\"%s\")", buffer, path); + + if (*i < value.size()) + value = value.get(*i); + else + throw TraversalException::format("Index %s out of range in pathGet(\"%s\")", buffer, path); + + } else if (value.type() == Json::Type::Object) { + if (value.contains(buffer)) + value = value.get(buffer); + else + throw TraversalException::format("No such key '%s' in pathGet(\"%s\")", buffer, path); + + } else { + throw TraversalException::format("Tried to get key '%s' in non-object type in pathGet(\"%s\")", buffer, path); + } + } + return value; + } + + template <typename Jsonlike> + Maybe<Jsonlike> pathFind(Jsonlike value, PathParser parser, String const& path) { + String buffer; + buffer.reserve(path.size()); + + auto pos = path.begin(); + + while (pos != path.end()) { + parser(buffer, path, pos, path.end()); + + if (value.type() == Json::Type::Array) { + if (buffer == "-") + return {}; + + Maybe<size_t> i = maybeLexicalCast<size_t>(buffer); + if (i && *i < value.size()) + value = value.get(*i); + else + return {}; + + } else if (value.type() == Json::Type::Object) { + if (value.contains(buffer)) + value = value.get(buffer); + else + return {}; + + } else { + return {}; + } + } + return value; + } + + template <typename Jsonlike> + Jsonlike pathApply(String& buffer, + Jsonlike const& value, + PathParser parser, + String const& path, + String::const_iterator const current, + JsonOp<Jsonlike> op) { + if (current == path.end()) + return op(value, {}); + + String::const_iterator iterator = current; + parser(buffer, path, iterator, path.end()); + + if (value.type() == Json::Type::Array) { + if (iterator == path.end()) { + return op(value, buffer); + } else { + Maybe<size_t> i = maybeLexicalCast<size_t>(buffer); + if (!i) + throw TraversalException::format("Cannot parse '%s' as index in pathApply(\"%s\")", buffer, path); + + if (*i >= value.size()) + throw TraversalException::format("Index %s out of range in pathApply(\"%s\")", buffer, path); + + return value.set(*i, pathApply(buffer, value.get(*i), parser, path, iterator, op)); + } + + } else if (value.type() == Json::Type::Object) { + if (iterator == path.end()) { + return op(value, buffer); + + } else { + if (!value.contains(buffer)) + throw TraversalException::format("No such key '%s' in pathApply(\"%s\")", buffer, path); + + Jsonlike newChild = pathApply(buffer, value.get(buffer), parser, path, iterator, op); + iterator = current; + // pathApply just mutated buffer. Recover the current path component: + parser(buffer, path, iterator, path.end()); + return value.set(buffer, newChild); + } + + } else { + throw TraversalException::format("Tried to get key '%s' in non-object type in pathApply(\"%s\")", buffer, path); + } + } + + template <typename Jsonlike> + Jsonlike pathApply(Jsonlike const& base, PathParser parser, String const& path, JsonOp<Jsonlike> op) { + String buffer; + return pathApply(buffer, base, parser, path, path.begin(), op); + } + + template <typename Jsonlike> + JsonOp<Jsonlike> genericObjectArrayOp(String path, EmptyPathOp<Jsonlike> emptyPathOp, ObjectOp<Jsonlike> objectOp, ArrayOp<Jsonlike> arrayOp) { + return [=](Jsonlike const& parent, Maybe<String> const& key) -> Jsonlike { + if (key.isNothing()) + return emptyPathOp(parent); + if (parent.type() == Json::Type::Array) { + if (*key == "-") + return arrayOp(parent, {}); + Maybe<size_t> i = maybeLexicalCast<size_t>(*key); + if (!i) + throw TraversalException::format("Cannot parse '%s' as index in Json path \"%s\"", *key, path); + if (i && *i > parent.size()) + throw TraversalException::format("Index %s out of range in Json path \"%s\"", *key, path); + if (i && *i == parent.size()) + i = {}; + return arrayOp(parent, i); + } else if (parent.type() == Json::Type::Object) { + return objectOp(parent, *key); + } else { + throw TraversalException::format("Tried to set key '%s' in non-object type in pathSet(\"%s\")", *key, path); + } + }; + } + + template <typename Jsonlike> + Jsonlike pathSet(Jsonlike const& base, PathParser parser, String const& path, Jsonlike const& value) { + EmptyPathOp<Jsonlike> emptyPathOp = [&value](Jsonlike const&) { + return value; + }; + ObjectOp<Jsonlike> objectOp = [&value](Jsonlike const& object, String const& key) { + return object.set(key, value); + }; + ArrayOp<Jsonlike> arrayOp = [&value](Jsonlike const& array, Maybe<size_t> i) { + if (i.isValid()) + return array.set(*i, value); + return array.append(value); + }; + return pathApply(base, parser, path, genericObjectArrayOp(path, emptyPathOp, objectOp, arrayOp)); + } + + template <typename Jsonlike> + Jsonlike pathRemove(Jsonlike const& base, PathParser parser, String const& path) { + EmptyPathOp<Jsonlike> emptyPathOp = [](Jsonlike const&) { return Json{}; }; + ObjectOp<Jsonlike> objectOp = [](Jsonlike const& object, String const& key) { + if (!object.contains(key)) + throw TraversalException::format("Could not find \"%s\" to remove", key); + return object.eraseKey(key); + }; + ArrayOp<Jsonlike> arrayOp = [](Jsonlike const& array, Maybe<size_t> i) { + if (i.isValid()) + return array.eraseIndex(*i); + throw TraversalException("Could not remove element after end of array"); + }; + return pathApply(base, parser, path, genericObjectArrayOp(path, emptyPathOp, objectOp, arrayOp)); + } + + template <typename Jsonlike> + Jsonlike pathAdd(Jsonlike const& base, PathParser parser, String const& path, Jsonlike const& value) { + EmptyPathOp<Jsonlike> emptyPathOp = [&value](Jsonlike const& document) { + if (document.type() == Json::Type::Null) + return value; + throw JsonException("Cannot add a value to the entire document, it is not empty."); + }; + ObjectOp<Jsonlike> objectOp = [&value](Jsonlike const& object, String const& key) { + return object.set(key, value); + }; + ArrayOp<Jsonlike> arrayOp = [&value](Jsonlike const& array, Maybe<size_t> i) { + if (i.isValid()) + return array.insert(*i, value); + return array.append(value); + }; + return pathApply(base, parser, path, genericObjectArrayOp(path, emptyPathOp, objectOp, arrayOp)); + } +} + +} + +#endif diff --git a/source/core/StarJsonRpc.cpp b/source/core/StarJsonRpc.cpp new file mode 100644 index 0000000..dcdc16e --- /dev/null +++ b/source/core/StarJsonRpc.cpp @@ -0,0 +1,115 @@ +#include "StarJsonRpc.hpp" +#include "StarDataStreamDevices.hpp" +#include "StarLogging.hpp" + +namespace Star { + +JsonRpcInterface::~JsonRpcInterface() {} + +JsonRpc::JsonRpc() { + m_requestId = 0; +} + +void JsonRpc::registerHandler(String const& handler, JsonRpcRemoteFunction func) { + if (m_handlers.contains(handler)) + throw JsonRpcException(strf("Handler by that name already exists '%s'", handler)); + m_handlers.add(move(handler), move(func)); +} + +void JsonRpc::registerHandlers(JsonRpcHandlers const& handlers) { + for (auto const& pair : handlers) + registerHandler(pair.first, pair.second); +} + +void JsonRpc::removeHandler(String const& handler) { + if (!m_handlers.contains(handler)) + throw JsonRpcException(strf("No such handler by the name '%s'", handler)); + + m_handlers.remove(handler); +} + +void JsonRpc::clearHandlers() { + m_handlers.clear(); +} + +RpcPromise<Json> JsonRpc::invokeRemote(String const& handler, Json const& arguments) { + uint64_t id = m_requestId++; + JsonObject request; + m_pending.append(JsonObject{ + {"command", "request"}, + {"id", id}, + {"handler", handler}, + {"arguments", arguments} + }); + + auto pair = RpcPromise<Json>::createPair(); + m_pendingResponse.add(id, pair.second); + + return pair.first; +} + +bool JsonRpc::sendPending() const { + return !m_pending.empty(); +} + +ByteArray JsonRpc::send() { + if (m_pending.empty()) + return {}; + + DataStreamBuffer buffer; + buffer.writeContainer(m_pending); + m_pending.clear(); + + return buffer.takeData(); +} + +void JsonRpc::receive(ByteArray const& inbuffer) { + if (inbuffer.empty()) + return; + + DataStreamBuffer buffer(inbuffer); + List<Json> inbound; + buffer.readContainer(inbound); + + for (auto request : inbound) { + if (request.get("command") == "request") { + try { + auto handlerName = request.getString("handler"); + if (!m_handlers.contains(handlerName)) + throw JsonRpcException(strf("Unknown handler '%s'", handlerName)); + m_pending.append(JsonObject{ + {"command", "response"}, + {"id", request.get("id")}, + {"result", m_handlers[handlerName](request.get("arguments"))} + }); + } catch (std::exception& e) { + Logger::error("Exception while handling variant rpc request handler call. %s", outputException(e, false)); + JsonObject response; + response["command"] = "fail"; + response["id"] = request.get("id"); + m_pending.append(JsonObject{ + {"command", "fail"}, + {"id", request.get("id")} + }); + } + + } else if (request.get("command") == "response") { + try { + auto responseHandler = m_pendingResponse.take(request.getUInt("id")); + responseHandler.fulfill(request.get("result")); + } catch (std::exception& e) { + Logger::error("Exception while handling variant rpc response handler call. %s", outputException(e, true)); + } + + } else if (request.get("command") == "fail") { + try { + auto responseHandler = m_pendingResponse.take(request.getUInt("id")); + responseHandler.fulfill({}); + } catch (std::exception& e) { + Logger::error("Exception while handling variant rpc failure handler call. %s", outputException(e, true)); + } + } + } +} + +} diff --git a/source/core/StarJsonRpc.hpp b/source/core/StarJsonRpc.hpp new file mode 100644 index 0000000..651e717 --- /dev/null +++ b/source/core/StarJsonRpc.hpp @@ -0,0 +1,54 @@ +#ifndef STAR_JSON_RPC_HPP +#define STAR_JSON_RPC_HPP + +#include "StarJson.hpp" +#include "StarByteArray.hpp" +#include "StarRpcPromise.hpp" + +namespace Star { + +STAR_CLASS(JsonRpcInterface); +STAR_CLASS(JsonRpc); + +STAR_EXCEPTION(JsonRpcException, StarException); + +typedef function<Json(Json const&)> JsonRpcRemoteFunction; + +typedef StringMap<JsonRpcRemoteFunction> JsonRpcHandlers; + +// Simple interface to just the method invocation part of JsonRpc. +class JsonRpcInterface { +public: + virtual ~JsonRpcInterface(); + virtual RpcPromise<Json> invokeRemote(String const& handler, Json const& arguments) = 0; +}; + +// Simple class to handle remote methods based on Json types. Does not +// handle any of the network details, simply turns rpc calls into ByteArray +// messages to be sent and received. +class JsonRpc : public JsonRpcInterface { +public: + JsonRpc(); + + void registerHandler(String const& handler, JsonRpcRemoteFunction func); + void registerHandlers(JsonRpcHandlers const& handlers); + + void removeHandler(String const& handler); + void clearHandlers(); + + RpcPromise<Json> invokeRemote(String const& handler, Json const& arguments) override; + + bool sendPending() const; + ByteArray send(); + void receive(ByteArray const& inbuffer); + +private: + JsonRpcHandlers m_handlers; + Map<uint64_t, RpcPromiseKeeper<Json>> m_pendingResponse; + List<Json> m_pending; + uint64_t m_requestId; +}; + +} + +#endif diff --git a/source/core/StarLexicalCast.hpp b/source/core/StarLexicalCast.hpp new file mode 100644 index 0000000..bee2692 --- /dev/null +++ b/source/core/StarLexicalCast.hpp @@ -0,0 +1,74 @@ +#ifndef STAR_LEXICAL_CAST_HPP +#define STAR_LEXICAL_CAST_HPP + +#include "StarString.hpp" +#include "StarMaybe.hpp" + +#include <sstream> +#include <locale> + +namespace Star { + +STAR_EXCEPTION(BadLexicalCast, StarException); + +// Very simple basic lexical cast using stream input. Always operates in the +// "C" locale. +template <typename Type> +Maybe<Type> maybeLexicalCast(std::string const& s, std::ios_base::fmtflags flags = std::ios_base::boolalpha) { + Type result; + std::istringstream stream(s); + stream.flags(flags); + stream.imbue(std::locale::classic()); + + if (!(stream >> result)) + return {}; + + // Confirm that we read everything out of the stream + char ch; + if (stream >> ch) + return {}; + + return result; +} + +template <typename Type> +Maybe<Type> maybeLexicalCast(char const* s, std::ios_base::fmtflags flags = std::ios_base::boolalpha) { + return maybeLexicalCast<Type>(std::string(s), flags); +} + +template <typename Type> +Maybe<Type> maybeLexicalCast(String const& s, std::ios_base::fmtflags flags = std::ios_base::boolalpha) { + return maybeLexicalCast<Type>(s.utf8(), flags); +} + +template <typename Type> +Type lexicalCast(std::string const& s, std::ios_base::fmtflags flags = std::ios_base::boolalpha) { + auto m = maybeLexicalCast<Type>(s, flags); + if (m) + return m.take(); + else + throw BadLexicalCast(); +} + +template <typename Type> +Type lexicalCast(char const* s, std::ios_base::fmtflags flags = std::ios_base::boolalpha) { + return lexicalCast<Type>(std::string(s), flags); +} + +template <typename Type> +Type lexicalCast(String const& s, std::ios_base::fmtflags flags = std::ios_base::boolalpha) { + return lexicalCast<Type>(s.utf8(), flags); +} + +template <class Type> +std::string toString(Type const& t, std::ios_base::fmtflags flags = std::ios_base::boolalpha) { + std::stringstream ss; + ss.flags(flags); + ss.imbue(std::locale::classic()); + ss << t; + return ss.str(); +} + +} + +#endif diff --git a/source/core/StarLine.hpp b/source/core/StarLine.hpp new file mode 100644 index 0000000..ccb3e26 --- /dev/null +++ b/source/core/StarLine.hpp @@ -0,0 +1,289 @@ +#ifndef STAR_LINE_HPP +#define STAR_LINE_HPP + +#include "StarMatrix3.hpp" + +namespace Star { + +template <typename T, size_t N> +class Line { +public: + typedef Vector<T, N> VectorType; + + struct IntersectResult { + // Whether or not the two objects intersect + bool intersects; + // Where the intersection is (minimum value if intersection occurs in more + // than one point.) + VectorType point; + // T value where intersection occurs, 0 is min, 1 is max + T t; + // Whether or not the two lines, if they were infinite lines, are the exact + // same line + bool coincides; + // Whether or not the intersection is a glancing one, meaning the other + // line isn't actually skewered, it's just barely touching Coincidental + // lines are always glancing intersections. + bool glances; + }; + + Line() {} + + template <typename T2> + explicit Line(Line<T2, N> const& line) + : m_min(line.min()), m_max(line.max()) {} + + Line(VectorType const& a, VectorType const& b) + : m_min(a), m_max(b) {} + + VectorType direction() const { + return diff().normalized(); + } + + T length() const { + return diff().magnitude(); + } + + T angle() const { + return diff().angle(); + } + + VectorType eval(T t) const { + return m_min + diff() * t; + } + + VectorType diff() const { + return (m_max - m_min); + } + + VectorType center() const { + return (m_min + m_max) / 2; + } + + void setCenter(VectorType c) { + return translate(c - center()); + } + + VectorType& min() { + return m_min; + } + + VectorType& max() { + return m_max; + } + + VectorType const& min() const { + return m_min; + } + + VectorType const& max() const { + return m_max; + } + + VectorType midpoint() const { + return (m_max + m_min) / 2; + } + + bool makePositive() { + bool changed = false; + for (unsigned i = 0; i < N; i++) { + if (m_min[i] < m_max[i]) { + break; + } else if (m_min[i] > m_max[i]) { + std::swap(m_min, m_max); + changed = true; + break; + } + } + return changed; + } + + void reverse() { + std::swap(m_min, m_max); + } + + Line reversed() { + return Line(m_max, m_min); + } + + void translate(VectorType const& trans) { + m_min += trans; + m_max += trans; + } + + Line translated(VectorType const& trans) { + return Line(m_min + trans, m_max + trans); + } + + void scale(VectorType const& s, VectorType const& c = VectorType()) { + m_min = vmult(m_min - c, s) + c; + m_max = vmult(m_max - c, s) + c; + } + + void scale(T s, VectorType const& c = VectorType()) { + scale(VectorType::filled(s), c); + } + + bool operator==(Line const& rhs) const { + return tie(m_min, m_max) == tie(rhs.m_min, rhs.m_max); + } + + bool operator<(Line const& rhs) const { + return tie(m_min, m_max) < tie(rhs.m_min, rhs.m_max); + } + + // Line2 + + template <size_t P = N> + typename std::enable_if<P == 2 && N == P, IntersectResult>::type intersection( + Line const& line2, bool infinite = false) const { + Line l1 = *this; + Line l2 = line2; + // Warning to others, do not make the lines positive, because points of + // intersection for coincidental lines are determined by the first point + // And makePositive() changes the order of points. This causes headaches + // later on + // l1.makePositive(); + // l2.makePositive(); + VectorType a = l1.min(); + VectorType b = l1.max(); + VectorType c = l2.min(); + VectorType d = l2.max(); + + VectorType ab = diff(); + VectorType cd = l2.diff(); + + T denom = ab ^ cd; + T xNumer = (a ^ b) * cd[0] - (c ^ d) * ab[0]; + T yNumer = (a ^ b) * cd[1] - (c ^ d) * ab[1]; + + IntersectResult isect; + if (nearZero(denom)) { // the lines are parallel unless + if (nearZero(xNumer) && nearZero(yNumer)) { // the lines are coincidental + isect.intersects = infinite || (a >= c && a <= d) || (c >= a && c <= b); + if (isect.intersects) { + // returns the minimum intersection point + if (infinite) { + isect.point = VectorType::filled(-std::numeric_limits<T>::max()); + } else { + isect.point = a < c ? c : a; + } + } + if (a < c) { + if (c[0] != a[0]) { + isect.t = (c[0] - a[0]) / ab[0]; + } else { + isect.t = (c[1] - a[1]) / ab[1]; + } + } else if (a > d) { + if (d[0] != a[0]) { + isect.t = (d[0] - a[0]) / ab[0]; + } else { + isect.t = (d[1] - a[1]) / ab[1]; + } + } else { + isect.t = 0; + } + isect.coincides = true; + isect.glances = isect.intersects; + } else { + isect.intersects = false; + isect.t = std::numeric_limits<T>::max(); + isect.point = VectorType(); + isect.coincides = false; + isect.glances = false; + } + } else { + T ta = ((c - a) ^ cd) / denom; + T tb = ((c - a) ^ ab) / denom; + + isect.intersects = infinite || (ta >= 0 && ta <= 1.0 && tb >= 0 && tb <= 1.0); + isect.t = ta; + isect.point = VectorType(ta * (b[0] - a[0]) + a[0], ta * (b[1] - a[1]) + a[1]); + isect.coincides = false; + isect.glances = !infinite && isect.intersects && (nearZero(ta) || nearEqual(ta, 1.0f) || nearZero(tb) || nearEqual(tb, 1.0f)); + } + return isect; + } + + template <size_t P = N> + typename std::enable_if<P == 2 && N == P, bool>::type intersects(Line const& l2, bool infinite = false) const { + return intersection(l2, infinite).intersects; + } + + // Returns t value for closest point on the line. t value is *not* clamped + // from 0.0 to 1.0 + template <size_t P = N> + typename std::enable_if<P == 2 && N == P, T>::type lineProjection(VectorType const& l2) const { + VectorType d = diff(); + return ((l2[0] - min()[0]) * d[0] + (l2[1] - min()[1]) * d[1]) / d.magnitudeSquared(); + } + + template <size_t P = N> + typename std::enable_if<P == 2 && N == P, T>::type distanceTo(VectorType const& l, bool infinite = false) const { + auto t = lineProjection(l); + if (!infinite) + t = clamp<T>(t, 0, 1); + return vmag(l - eval(t)); + } + + template <size_t P = N> + typename std::enable_if<P == 2 && N == P, void>::type rotate( + T angle, VectorType const& rotationCenter = VectorType()) { + auto rotMatrix = Mat3F::rotation(angle, rotationCenter); + min() = rotMatrix.transformVec2(min()); + max() = rotMatrix.transformVec2(max()); + } + + template <typename T2, size_t P = N> + typename std::enable_if<P == 2 && N == P, void>::type transform(Matrix3<T2> const& transform) { + min() = transform.transformVec2(min()); + max() = transform.transformVec2(max()); + } + + template <typename T2, size_t P = N> + typename std::enable_if<P == 2 && N == P, Line>::type transformed(Matrix3<T2> const& transform) const { + return Line(transform.transformVec2(min()), transform.transformVec2(max())); + } + + template <size_t P = N> + typename std::enable_if<P == 2 && N == P, void>::type flipHorizontal(T horizontalPos) { + m_min[0] = horizontalPos + (horizontalPos - m_min[0]); + m_max[0] = horizontalPos + (horizontalPos - m_max[0]); + } + + template <size_t P = N> + typename std::enable_if<P == 2 && N == P, void>::type flipVertical(T verticalPos) { + m_min[1] = verticalPos + (verticalPos - m_min[1]); + m_max[1] = verticalPos + (verticalPos - m_max[1]); + } + +private: + VectorType m_min; + VectorType m_max; +}; + +typedef Line<float, 2> Line2F; +typedef Line<double, 2> Line2D; +typedef Line<int, 2> Line2I; + +template <typename T, size_t N> +std::ostream& operator<<(std::ostream& os, Line<T, N> const& l) { + os << '[' << l.min() << ", " << l.max() << ']'; + return os; +} + +template <typename T, size_t N> +struct hash<Line<T, N>> { + size_t operator()(Line<T, N> const& line) const { + size_t hashval = 0; + hashCombine(hashval, vectorHasher(line.min())); + hashCombine(hashval, vectorHasher(line.max())); + return hashval; + } + Star::hash<typename Line<T, N>::VectorType> vectorHasher; +}; + +} + +#endif diff --git a/source/core/StarList.hpp b/source/core/StarList.hpp new file mode 100644 index 0000000..bfedf54 --- /dev/null +++ b/source/core/StarList.hpp @@ -0,0 +1,1138 @@ +#ifndef STAR_LIST_HPP +#define STAR_LIST_HPP + +#include <vector> +#include <deque> +#include <list> + +#include "StarException.hpp" +#include "StarStaticVector.hpp" +#include "StarSmallVector.hpp" +#include "StarPythonic.hpp" +#include "StarMaybe.hpp" +#include "StarFormat.hpp" + +namespace Star { + +template <typename BaseList> +class ListMixin : public BaseList { +public: + typedef BaseList Base; + + typedef typename Base::iterator iterator; + typedef typename Base::const_iterator const_iterator; + typedef typename Base::value_type value_type; + typedef typename Base::reference reference; + typedef typename Base::const_reference const_reference; + + ListMixin(); + ListMixin(Base const& list); + ListMixin(Base&& list); + ListMixin(value_type const* p, size_t count); + template <typename InputIterator> + ListMixin(InputIterator beg, InputIterator end); + explicit ListMixin(size_t len, const_reference s1 = value_type()); + ListMixin(initializer_list<value_type> list); + + void append(value_type e); + + template <typename Container> + void appendAll(Container&& list); + + template <class... Args> + reference emplaceAppend(Args&&... args); + + reference first(); + const_reference first() const; + + reference last(); + const_reference last() const; + + Maybe<value_type> maybeFirst(); + Maybe<value_type> maybeLast(); + + void removeLast(); + value_type takeLast(); + + Maybe<value_type> maybeTakeLast(); + + // Limit the size of the list by removing elements from the back until the + // size is the maximumSize or less. + void limitSizeBack(size_t maximumSize); + + size_t count() const; + + bool contains(const_reference e) const; + // Remove all equal to element, returns number removed. + size_t remove(const_reference e); + + template <typename Filter> + void filter(Filter&& filter); + + template <typename Comparator> + void insertSorted(value_type e, Comparator&& comparator); + void insertSorted(value_type e); + + // Returns true if this *sorted* list contains the given element. + template <typename Comparator> + bool containsSorted(value_type const& e, Comparator&& comparator); + bool containsSorted(value_type e); + + template <typename Function> + void exec(Function&& function); + + template <typename Function> + void exec(Function&& function) const; + + template <typename Function> + void transform(Function&& function); + + template <typename Function> + bool any(Function&& function) const; + bool any() const; + + template <typename Function> + bool all(Function&& function) const; + bool all() const; +}; + +template <typename List> +class ListHasher { +public: + size_t operator()(List const& l) const; + +private: + hash<typename List::value_type> elemHasher; +}; + +template <typename BaseList> +class RandomAccessListMixin : public BaseList { +public: + typedef BaseList Base; + + typedef typename Base::iterator iterator; + typedef typename Base::const_iterator const_iterator; + typedef typename Base::value_type value_type; + typedef typename Base::reference reference; + typedef typename Base::const_reference const_reference; + + using Base::Base; + + template <typename Comparator> + void sort(Comparator&& comparator); + void sort(); + + void reverse(); + + // Returns first index of given element, NPos if not found. + size_t indexOf(const_reference e, size_t from = 0) const; + // Returns last index of given element, NPos if not found. + size_t lastIndexOf(const_reference e, size_t til = NPos) const; + + const_reference at(size_t n) const; + reference at(size_t n); + + const_reference operator[](size_t n) const; + reference operator[](size_t n); + + // Does not throw if n is beyond end of list, instead returns def + value_type get(size_t n, value_type def = value_type()) const; + + value_type takeAt(size_t i); + + // Same as at, but wraps around back to the beginning + // (throws if list is empty) + const_reference wrap(size_t n) const; + reference wrap(size_t n); + + // Does not throw if list is empty + value_type wrap(size_t n, value_type def) const; + + void eraseAt(size_t index); + // Erases region from begin to end, not including end. + void eraseAt(size_t begin, size_t end); + + void insertAt(size_t pos, value_type e); + + template <typename Container> + void insertAllAt(size_t pos, Container const& l); + + // Ensures that list is large enough to hold pos elements. + void set(size_t pos, value_type e); + + void swap(size_t i, size_t j); + // same as insert(to, takeAt(from)) + void move(size_t from, size_t to); +}; + +template <typename BaseList> +class FrontModifyingListMixin : public BaseList { +public: + typedef BaseList Base; + + typedef typename Base::iterator iterator; + typedef typename Base::const_iterator const_iterator; + typedef typename Base::value_type value_type; + typedef typename Base::reference reference; + typedef typename Base::const_reference const_reference; + + using Base::Base; + + void prepend(value_type e); + + template <typename Container> + void prependAll(Container&& list); + + template <class... Args> + reference emplacePrepend(Args&&... args); + + void removeFirst(); + value_type takeFirst(); + + // Limit the size of the list by removing elements from the front until the + // size is the maximumSize or less. + void limitSizeFront(size_t maximumSize); +}; + +template <typename Element, typename Allocator = std::allocator<Element>> +class List : public RandomAccessListMixin<ListMixin<std::vector<Element, Allocator>>> { +public: + typedef RandomAccessListMixin<ListMixin<std::vector<Element, Allocator>>> Base; + + typedef typename Base::iterator iterator; + typedef typename Base::const_iterator const_iterator; + typedef typename Base::value_type value_type; + typedef typename Base::reference reference; + typedef typename Base::const_reference const_reference; + + template <typename Container> + static List from(Container const& c); + + using Base::Base; + + // Pointer to contiguous storage, returns nullptr if empty + value_type* ptr(); + value_type const* ptr() const; + + List slice(SliceIndex a = SliceIndex(), SliceIndex b = SliceIndex(), int i = 1) const; + + template <typename Filter> + List filtered(Filter&& filter) const; + + template <typename Comparator> + List sorted(Comparator&& comparator) const; + List sorted() const; + + template <typename Function> + auto transformed(Function&& function); + + template <typename Function> + auto transformed(Function&& function) const; +}; + +template <typename Element, typename Allocator> +struct hash<List<Element, Allocator>> : public ListHasher<List<Element, Allocator>> {}; + +template <typename Element, size_t MaxSize> +class StaticList : public RandomAccessListMixin<ListMixin<StaticVector<Element, MaxSize>>> { +public: + typedef RandomAccessListMixin<ListMixin<StaticVector<Element, MaxSize>>> Base; + + typedef typename Base::iterator iterator; + typedef typename Base::const_iterator const_iterator; + typedef typename Base::value_type value_type; + typedef typename Base::reference reference; + typedef typename Base::const_reference const_reference; + + template <typename Container> + static StaticList from(Container const& c); + + using Base::Base; + + StaticList slice(SliceIndex a = SliceIndex(), SliceIndex b = SliceIndex(), int i = 1) const; + + template <typename Filter> + StaticList filtered(Filter&& filter) const; + + template <typename Comparator> + StaticList sorted(Comparator&& comparator) const; + StaticList sorted() const; + + template <typename Function> + auto transformed(Function&& function); + + template <typename Function> + auto transformed(Function&& function) const; +}; + +template <typename Element, size_t MaxStackSize> +struct hash<StaticList<Element, MaxStackSize>> : public ListHasher<StaticList<Element, MaxStackSize>> {}; + +template <typename Element, size_t MaxStackSize> +class SmallList : public RandomAccessListMixin<ListMixin<SmallVector<Element, MaxStackSize>>> { +public: + typedef RandomAccessListMixin<ListMixin<SmallVector<Element, MaxStackSize>>> Base; + + typedef typename Base::iterator iterator; + typedef typename Base::const_iterator const_iterator; + typedef typename Base::value_type value_type; + typedef typename Base::reference reference; + typedef typename Base::const_reference const_reference; + + template <typename Container> + static SmallList from(Container const& c); + + using Base::Base; + + SmallList slice(SliceIndex a = SliceIndex(), SliceIndex b = SliceIndex(), int i = 1) const; + + template <typename Filter> + SmallList filtered(Filter&& filter) const; + + template <typename Comparator> + SmallList sorted(Comparator&& comparator) const; + SmallList sorted() const; + + template <typename Function> + auto transformed(Function&& function); + + template <typename Function> + auto transformed(Function&& function) const; +}; + +template <typename Element, size_t MaxStackSize> +struct hash<SmallList<Element, MaxStackSize>> : public ListHasher<SmallList<Element, MaxStackSize>> {}; + +template <typename Element, typename Allocator = std::allocator<Element>> +class Deque : public FrontModifyingListMixin<RandomAccessListMixin<ListMixin<std::deque<Element, Allocator>>>> { +public: + typedef FrontModifyingListMixin<RandomAccessListMixin<ListMixin<std::deque<Element, Allocator>>>> Base; + + typedef typename Base::iterator iterator; + typedef typename Base::const_iterator const_iterator; + typedef typename Base::value_type value_type; + typedef typename Base::reference reference; + typedef typename Base::const_reference const_reference; + + template <typename Container> + static Deque from(Container const& c); + + using Base::Base; + + Deque slice(SliceIndex a = SliceIndex(), SliceIndex b = SliceIndex(), int i = 1) const; + + template <typename Filter> + Deque filtered(Filter&& filter) const; + + template <typename Comparator> + Deque sorted(Comparator&& comparator) const; + Deque sorted() const; + + template <typename Function> + auto transformed(Function&& function); + + template <typename Function> + auto transformed(Function&& function) const; +}; + +template <typename Element, typename Allocator> +struct hash<Deque<Element, Allocator>> : public ListHasher<Deque<Element, Allocator>> {}; + +template <typename Element, typename Allocator = std::allocator<Element>> +class LinkedList : public FrontModifyingListMixin<ListMixin<std::list<Element, Allocator>>> { +public: + typedef FrontModifyingListMixin<ListMixin<std::list<Element, Allocator>>> Base; + + typedef typename Base::iterator iterator; + typedef typename Base::const_iterator const_iterator; + typedef typename Base::value_type value_type; + typedef typename Base::reference reference; + typedef typename Base::const_reference const_reference; + + template <typename Container> + static LinkedList from(Container const& c); + + using Base::Base; + + void appendAll(LinkedList list); + void prependAll(LinkedList list); + + template <typename Container> + void appendAll(Container&& list); + template <typename Container> + void prependAll(Container&& list); + + template <typename Filter> + LinkedList filtered(Filter&& filter) const; + + template <typename Comparator> + LinkedList sorted(Comparator&& comparator) const; + LinkedList sorted() const; + + template <typename Function> + auto transformed(Function&& function); + + template <typename Function> + auto transformed(Function&& function) const; +}; + +template <typename Element, typename Allocator> +struct hash<LinkedList<Element, Allocator>> : public ListHasher<LinkedList<Element, Allocator>> {}; + +template <typename BaseList> +std::ostream& operator<<(std::ostream& os, ListMixin<BaseList> const& list); + +template <typename... Containers> +struct ListZipTypes { + typedef tuple<typename std::decay<Containers>::type::value_type...> Tuple; + typedef List<Tuple> Result; +}; + +template <typename... Containers> +typename ListZipTypes<Containers...>::Result zip(Containers&&... args); + +template <typename Container> +struct ListEnumerateTypes { + typedef pair<typename std::decay<Container>::type::value_type, size_t> Pair; + typedef List<Pair> Result; +}; + +template <typename Container> +typename ListEnumerateTypes<Container>::Result enumerate(Container&& container); + +template <typename BaseList> +ListMixin<BaseList>::ListMixin() + : Base() {} + +template <typename BaseList> +ListMixin<BaseList>::ListMixin(Base const& list) + : Base(list) {} + +template <typename BaseList> +ListMixin<BaseList>::ListMixin(Base&& list) + : Base(std::move(list)) {} + +template <typename BaseList> +ListMixin<BaseList>::ListMixin(size_t len, const_reference s1) + : Base(len, s1) {} + +template <typename BaseList> +ListMixin<BaseList>::ListMixin(value_type const* p, size_t count) + : Base(p, p + count) {} + +template <typename BaseList> +template <typename InputIterator> +ListMixin<BaseList>::ListMixin(InputIterator beg, InputIterator end) + : Base(beg, end) {} + +template <typename BaseList> +ListMixin<BaseList>::ListMixin(initializer_list<value_type> list) { + // In case underlying class type doesn't support initializer_list + for (auto& e : list) + append(std::move(e)); +} + +template <typename BaseList> +void ListMixin<BaseList>::append(value_type e) { + Base::push_back(std::move(e)); +} + +template <typename BaseList> +template <typename Container> +void ListMixin<BaseList>::appendAll(Container&& list) { + for (auto& e : list) { + if (std::is_rvalue_reference<Container&&>::value) + Base::push_back(std::move(e)); + else + Base::push_back(e); + } +} + +template <typename BaseList> +template <class... Args> +auto ListMixin<BaseList>::emplaceAppend(Args&&... args) -> reference { + Base::emplace_back(forward<Args>(args)...); + return *prev(Base::end()); +} + +template <typename BaseList> +auto ListMixin<BaseList>::first() -> reference { + if (Base::empty()) + throw OutOfRangeException("first() called on empty list"); + return *Base::begin(); +} + +template <typename BaseList> +auto ListMixin<BaseList>::first() const -> const_reference { + if (Base::empty()) + throw OutOfRangeException("first() called on empty list"); + return *Base::begin(); +} + +template <typename BaseList> +auto ListMixin<BaseList>::last() -> reference { + if (Base::empty()) + throw OutOfRangeException("last() called on empty list"); + return *prev(Base::end()); +} + +template <typename BaseList> +auto ListMixin<BaseList>::last() const -> const_reference { + if (Base::empty()) + throw OutOfRangeException("last() called on empty list"); + return *prev(Base::end()); +} + +template <typename BaseList> +auto ListMixin<BaseList>::maybeFirst() -> Maybe<value_type> { + if (Base::empty()) + return {}; + return *Base::begin(); +} + +template <typename BaseList> +auto ListMixin<BaseList>::maybeLast() -> Maybe<value_type> { + if (Base::empty()) + return {}; + return *prev(Base::end()); +} + +template <typename BaseList> +void ListMixin<BaseList>::removeLast() { + if (Base::empty()) + throw OutOfRangeException("removeLast() called on empty list"); + Base::pop_back(); +} + +template <typename BaseList> +auto ListMixin<BaseList>::takeLast() -> value_type { + value_type e = std::move(last()); + Base::pop_back(); + return e; +} + +template <typename BaseList> +auto ListMixin<BaseList>::maybeTakeLast() -> Maybe<value_type> { + if (Base::empty()) + return {}; + value_type e = std::move(last()); + Base::pop_back(); + return e; +} + +template <typename BaseList> +void ListMixin<BaseList>::limitSizeBack(size_t maximumSize) { + while (Base::size() > maximumSize) + Base::pop_back(); +} + +template <typename BaseList> +size_t ListMixin<BaseList>::count() const { + return Base::size(); +} + +template <typename BaseList> +bool ListMixin<BaseList>::contains(const_reference e) const { + for (auto const& r : *this) { + if (r == e) + return true; + } + return false; +} + +template <typename BaseList> +size_t ListMixin<BaseList>::remove(const_reference e) { + size_t removed = 0; + auto i = Base::begin(); + while (i != Base::end()) { + if (*i == e) { + ++removed; + i = Base::erase(i); + } else { + ++i; + } + } + return removed; +} + +template <typename BaseList> +template <typename Filter> +void ListMixin<BaseList>::filter(Filter&& filter) { + Star::filter(*this, forward<Filter>(filter)); +} + +template <typename BaseList> +template <typename Comparator> +void ListMixin<BaseList>::insertSorted(value_type e, Comparator&& comparator) { + auto i = std::upper_bound(Base::begin(), Base::end(), e, forward<Comparator>(comparator)); + Base::insert(i, std::move(e)); +} + +template <typename BaseList> +void ListMixin<BaseList>::insertSorted(value_type e) { + auto i = std::upper_bound(Base::begin(), Base::end(), e); + Base::insert(i, std::move(e)); +} + +template <typename BaseList> +template <typename Comparator> +bool ListMixin<BaseList>::containsSorted(value_type const& e, Comparator&& comparator) { + auto range = std::equal_range(Base::begin(), Base::end(), e, forward<Comparator>(comparator)); + return range.first != range.second; +} + +template <typename BaseList> +bool ListMixin<BaseList>::containsSorted(value_type e) { + auto range = std::equal_range(Base::begin(), Base::end(), e); + return range.first != range.second; +} + +template <typename BaseList> +template <typename Function> +void ListMixin<BaseList>::exec(Function&& function) { + for (auto& e : *this) + function(e); +} + +template <typename BaseList> +template <typename Function> +void ListMixin<BaseList>::exec(Function&& function) const { + for (auto const& e : *this) + function(e); +} + +template <typename BaseList> +template <typename Function> +void ListMixin<BaseList>::transform(Function&& function) { + for (auto& e : *this) + e = function(e); +} + +template <typename BaseList> +template <typename Function> +bool ListMixin<BaseList>::any(Function&& function) const { + return Star::any(*this, forward<Function>(function)); +} + +template <typename BaseList> +bool ListMixin<BaseList>::any() const { + return Star::any(*this); +} + +template <typename BaseList> +template <typename Function> +bool ListMixin<BaseList>::all(Function&& function) const { + return Star::all(*this, forward<Function>(function)); +} + +template <typename BaseList> +bool ListMixin<BaseList>::all() const { + return Star::all(*this); +} + +template <typename BaseList> +template <typename Comparator> +void RandomAccessListMixin<BaseList>::sort(Comparator&& comparator) { + Star::sort(*this, forward<Comparator>(comparator)); +} + +template <typename BaseList> +void RandomAccessListMixin<BaseList>::sort() { + Star::sort(*this); +} + +template <typename BaseList> +void RandomAccessListMixin<BaseList>::reverse() { + Star::reverse(*this); +} + +template <typename BaseList> +size_t RandomAccessListMixin<BaseList>::indexOf(const_reference e, size_t from) const { + for (size_t i = from; i < Base::size(); ++i) + if (operator[](i) == e) + return i; + return NPos; +} + +template <typename BaseList> +size_t RandomAccessListMixin<BaseList>::lastIndexOf(const_reference e, size_t til) const { + size_t index = NPos; + size_t end = std::min(Base::size(), til); + for (size_t i = 0; i < end; ++i) { + if (operator[](i) == e) + index = i; + } + return index; +} + +template <typename BaseList> +auto RandomAccessListMixin<BaseList>::at(size_t n) const -> const_reference { + if (n >= Base::size()) + throw OutOfRangeException(strf("out of range list::at(%s)", n)); + return operator[](n); +} + +template <typename BaseList> +auto RandomAccessListMixin<BaseList>::at(size_t n) -> reference { + if (n >= Base::size()) + throw OutOfRangeException(strf("out of range list::at(%s)", n)); + return operator[](n); +} + +template <typename BaseList> +auto RandomAccessListMixin<BaseList>::operator[](size_t n) const -> const_reference { + starAssert(n < Base::size()); + return Base::operator[](n); +} + +template <typename BaseList> +auto RandomAccessListMixin<BaseList>::operator[](size_t n) -> reference { + starAssert(n < Base::size()); + return Base::operator[](n); +} + +template <typename BaseList> +auto RandomAccessListMixin<BaseList>::get(size_t n, value_type def) const -> value_type { + if (n >= BaseList::size()) + return def; + return operator[](n); +} + +template <typename BaseList> +auto RandomAccessListMixin<BaseList>::takeAt(size_t i) -> value_type { + value_type e = at(i); + Base::erase(Base::begin() + i); + return e; +} + +template <typename BaseList> +auto RandomAccessListMixin<BaseList>::wrap(size_t n) const -> const_reference { + if (BaseList::empty()) + throw OutOfRangeException(); + else + return operator[](n % BaseList::size()); +} + +template <typename BaseList> +auto RandomAccessListMixin<BaseList>::wrap(size_t n) -> reference { + if (BaseList::empty()) + throw OutOfRangeException(); + else + return operator[](n % BaseList::size()); +} + +template <typename BaseList> +auto RandomAccessListMixin<BaseList>::wrap(size_t n, value_type def) const -> value_type { + if (BaseList::empty()) + return def; + else + return operator[](n % BaseList::size()); +} + +template <typename BaseList> +void RandomAccessListMixin<BaseList>::eraseAt(size_t i) { + starAssert(i < Base::size()); + Base::erase(Base::begin() + i); +} + +template <typename BaseList> +void RandomAccessListMixin<BaseList>::eraseAt(size_t b, size_t e) { + starAssert(b < Base::size() && e <= Base::size()); + Base::erase(Base::begin() + b, Base::begin() + e); +} + +template <typename BaseList> +void RandomAccessListMixin<BaseList>::insertAt(size_t pos, value_type e) { + starAssert(pos <= Base::size()); + Base::insert(Base::begin() + pos, std::move(e)); +} + +template <typename BaseList> +template <typename Container> +void RandomAccessListMixin<BaseList>::insertAllAt(size_t pos, Container const& l) { + starAssert(pos <= Base::size()); + Base::insert(Base::begin() + pos, l.begin(), l.end()); +} + +template <typename BaseList> +void RandomAccessListMixin<BaseList>::set(size_t pos, value_type e) { + if (pos >= Base::size()) + Base::resize(pos + 1); + operator[](pos) = std::move(e); +} + +template <typename BaseList> +void RandomAccessListMixin<BaseList>::swap(size_t i, size_t j) { + std::swap(operator[](i), operator[](j)); +} + +template <typename BaseList> +void RandomAccessListMixin<BaseList>::move(size_t from, size_t to) { + Base::insert(to, takeAt(from)); +} + +template <typename BaseList> +void FrontModifyingListMixin<BaseList>::prepend(value_type e) { + Base::push_front(std::move(e)); +} + +template <typename BaseList> +template <typename Container> +void FrontModifyingListMixin<BaseList>::prependAll(Container&& list) { + for (auto i = std::rbegin(list); i != std::rend(list); ++i) { + if (std::is_rvalue_reference<Container&&>::value) + Base::push_front(std::move(*i)); + else + Base::push_front(*i); + } +} + +template <typename BaseList> +template <class... Args> +auto FrontModifyingListMixin<BaseList>::emplacePrepend(Args&&... args) -> reference { + Base::emplace_front(forward<Args>(args)...); + return *Base::begin(); +} + +template <typename BaseList> +void FrontModifyingListMixin<BaseList>::removeFirst() { + if (Base::empty()) + throw OutOfRangeException("removeFirst() called on empty list"); + Base::pop_front(); +} + +template <typename BaseList> +auto FrontModifyingListMixin<BaseList>::takeFirst() -> value_type { + value_type e = std::move(Base::first()); + Base::pop_front(); + return e; +} + +template <typename BaseList> +void FrontModifyingListMixin<BaseList>::limitSizeFront(size_t maximumSize) { + while (Base::size() > maximumSize) + Base::pop_front(); +} + +template <typename Element, typename Allocator> +template <typename Container> +List<Element, Allocator> List<Element, Allocator>::from(Container const& c) { + return List(c.begin(), c.end()); +} + +template <typename Element, typename Allocator> +auto List<Element, Allocator>::ptr() -> value_type * { + return Base::data(); +} + +template <typename Element, typename Allocator> +auto List<Element, Allocator>::ptr() const -> value_type const * { + return Base::data(); +} + +template <typename Element, typename Allocator> +auto List<Element, Allocator>::slice(SliceIndex a, SliceIndex b, int i) const -> List { + return Star::slice(*this, a, b, i); +} + +template <typename Element, typename Allocator> +template <typename Filter> +auto List<Element, Allocator>::filtered(Filter&& filter) const -> List { + List list(*this); + list.filter(forward<Filter>(filter)); + return list; +} + +template <typename Element, typename Allocator> +template <typename Comparator> +auto List<Element, Allocator>::sorted(Comparator&& comparator) const -> List { + List list(*this); + list.sort(forward<Comparator>(comparator)); + return list; +} + +template <typename Element, typename Allocator> +List<Element, Allocator> List<Element, Allocator>::sorted() const { + List list(*this); + list.sort(); + return list; +} + +template <typename Element, typename Allocator> +template <typename Function> +auto List<Element, Allocator>::transformed(Function&& function) { + List<typename std::decay<decltype(std::declval<Function>()(std::declval<reference>()))>::type> res; + res.reserve(Base::size()); + transformInto(res, *this, forward<Function>(function)); + return res; +} + +template <typename Element, typename Allocator> +template <typename Function> +auto List<Element, Allocator>::transformed(Function&& function) const { + List<typename std::decay<decltype(std::declval<Function>()(std::declval<const_reference>()))>::type> res; + res.reserve(Base::size()); + transformInto(res, *this, forward<Function>(function)); + return res; +} + +template <typename Element, size_t MaxSize> +template <typename Container> +StaticList<Element, MaxSize> StaticList<Element, MaxSize>::from(Container const& c) { + return StaticList(c.begin(), c.end()); +} + +template <typename Element, size_t MaxSize> +auto StaticList<Element, MaxSize>::slice(SliceIndex a, SliceIndex b, int i) const -> StaticList { + return Star::slice(*this, a, b, i); +} + +template <typename Element, size_t MaxSize> +template <typename Filter> +auto StaticList<Element, MaxSize>::filtered(Filter&& filter) const -> StaticList { + StaticList list(*this); + list.filter(forward<Filter>(filter)); + return list; +} + +template <typename Element, size_t MaxSize> +template <typename Comparator> +auto StaticList<Element, MaxSize>::sorted(Comparator&& comparator) const -> StaticList { + StaticList list(*this); + list.sort(forward<Comparator>(comparator)); + return list; +} + +template <typename Element, size_t MaxSize> +StaticList<Element, MaxSize> StaticList<Element, MaxSize>::sorted() const { + StaticList list(*this); + list.sort(); + return list; +} + +template <typename Element, size_t MaxSize> +template <typename Function> +auto StaticList<Element, MaxSize>::transformed(Function&& function) { + StaticList<typename std::decay<decltype(std::declval<Function>()(std::declval<reference>()))>::type, MaxSize> res; + transformInto(res, *this, forward<Function>(function)); + return res; +} + +template <typename Element, size_t MaxSize> +template <typename Function> +auto StaticList<Element, MaxSize>::transformed(Function&& function) const { + StaticList<typename std::decay<decltype(std::declval<Function>()(std::declval<const_reference>()))>::type, MaxSize> res; + transformInto(res, *this, forward<Function>(function)); + return res; +} + +template <typename Element, size_t MaxStackSize> +template <typename Container> +SmallList<Element, MaxStackSize> SmallList<Element, MaxStackSize>::from(Container const& c) { + return SmallList(c.begin(), c.end()); +} + +template <typename Element, size_t MaxStackSize> +auto SmallList<Element, MaxStackSize>::slice(SliceIndex a, SliceIndex b, int i) const -> SmallList { + return Star::slice(*this, a, b, i); +} + +template <typename Element, size_t MaxStackSize> +template <typename Filter> +auto SmallList<Element, MaxStackSize>::filtered(Filter&& filter) const -> SmallList { + SmallList list(*this); + list.filter(forward<Filter>(filter)); + return list; +} + +template <typename Element, size_t MaxStackSize> +template <typename Comparator> +auto SmallList<Element, MaxStackSize>::sorted(Comparator&& comparator) const -> SmallList { + SmallList list(*this); + list.sort(forward<Comparator>(comparator)); + return list; +} + +template <typename Element, size_t MaxStackSize> +SmallList<Element, MaxStackSize> SmallList<Element, MaxStackSize>::sorted() const { + SmallList list(*this); + list.sort(); + return list; +} + +template <typename Element, size_t MaxStackSize> +template <typename Function> +auto SmallList<Element, MaxStackSize>::transformed(Function&& function) { + SmallList<typename std::decay<decltype(std::declval<Function>()(std::declval<reference>()))>::type, MaxStackSize> res; + transformInto(res, *this, forward<Function>(function)); + return res; +} + +template <typename Element, size_t MaxStackSize> +template <typename Function> +auto SmallList<Element, MaxStackSize>::transformed(Function&& function) const { + SmallList<typename std::decay<decltype(std::declval<Function>()(std::declval<const_reference>()))>::type, MaxStackSize> res; + transformInto(res, *this, forward<Function>(function)); + return res; +} + +template <typename Element, typename Allocator> +template <typename Container> +Deque<Element, Allocator> Deque<Element, Allocator>::from(Container const& c) { + return Deque(c.begin(), c.end()); +} + +template <typename Element, typename Allocator> +Deque<Element, Allocator> Deque<Element, Allocator>::slice(SliceIndex a, SliceIndex b, int i) const { + return Star::slice(*this, a, b, i); +} + +template <typename Element, typename Allocator> +template <typename Filter> +Deque<Element, Allocator> Deque<Element, Allocator>::filtered(Filter&& filter) const { + Deque l(*this); + l.filter(forward<Filter>(filter)); + return l; +} + +template <typename Element, typename Allocator> +template <typename Comparator> +Deque<Element, Allocator> Deque<Element, Allocator>::sorted(Comparator&& comparator) const { + Deque l(*this); + l.sort(forward<Comparator>(comparator)); + return l; +} + +template <typename Element, typename Allocator> +Deque<Element, Allocator> Deque<Element, Allocator>::sorted() const { + Deque l(*this); + l.sort(); + return l; +} + +template <typename Element, typename Allocator> +template <typename Function> +auto Deque<Element, Allocator>::transformed(Function&& function) { + return Star::transform<Deque<decltype(std::declval<Function>()(std::declval<reference>()))>>(*this, forward<Function>(function)); +} + +template <typename Element, typename Allocator> +template <typename Function> +auto Deque<Element, Allocator>::transformed(Function&& function) const { + return Star::transform<Deque<decltype(std::declval<Function>()(std::declval<const_reference>()))>>(*this, forward<Function>(function)); +} + +template <typename Element, typename Allocator> +template <typename Container> +LinkedList<Element, Allocator> LinkedList<Element, Allocator>::from(Container const& c) { + return LinkedList(c.begin(), c.end()); +} + +template <typename Element, typename Allocator> +void LinkedList<Element, Allocator>::appendAll(LinkedList list) { + Base::splice(Base::end(), list); +} + +template <typename Element, typename Allocator> +void LinkedList<Element, Allocator>::prependAll(LinkedList list) { + Base::splice(Base::begin(), list); +} + +template <typename Element, typename Allocator> +template <typename Container> +void LinkedList<Element, Allocator>::appendAll(Container&& list) { + for (auto& e : list) { + if (std::is_rvalue_reference<Container&&>::value) + Base::push_back(std::move(e)); + else + Base::push_back(e); + } +} + +template <typename Element, typename Allocator> +template <typename Container> +void LinkedList<Element, Allocator>::prependAll(Container&& list) { + for (auto i = std::rbegin(list); i != std::rend(list); ++i) { + if (std::is_rvalue_reference<Container&&>::value) + Base::push_front(std::move(*i)); + else + Base::push_front(*i); + } +} + +template <typename Element, typename Allocator> +template <typename Filter> +LinkedList<Element, Allocator> LinkedList<Element, Allocator>::filtered(Filter&& filter) const { + LinkedList list(*this); + list.filter(forward<Filter>(filter)); + return list; +} + +template <typename Element, typename Allocator> +template <typename Comparator> +LinkedList<Element, Allocator> LinkedList<Element, Allocator>::sorted(Comparator&& comparator) const { + LinkedList l(*this); + l.sort(forward<Comparator>(comparator)); + return l; +} + +template <typename Element, typename Allocator> +LinkedList<Element, Allocator> LinkedList<Element, Allocator>::sorted() const { + LinkedList l(*this); + l.sort(); + return l; +} + +template <typename Element, typename Allocator> +template <typename Function> +auto LinkedList<Element, Allocator>::transformed(Function&& function) { + return Star::transform<LinkedList<decltype(std::declval<Function>()(std::declval<reference>()))>>(*this, forward<Function>(function)); +} + +template <typename Element, typename Allocator> +template <typename Function> +auto LinkedList<Element, Allocator>::transformed(Function&& function) const { + return Star::transform<LinkedList<decltype(std::declval<Function>()(std::declval<const_reference>()))>>(*this, forward<Function>(function)); +} + +template <typename BaseList> +std::ostream& operator<<(std::ostream& os, ListMixin<BaseList> const& list) { + os << "("; + for (auto i = list.begin(); i != list.end(); ++i) { + if (i != list.begin()) + os << ", "; + os << *i; + } + os << ")"; + return os; +} + +template <typename List> +size_t ListHasher<List>::operator()(List const& l) const { + size_t h = 0; + for (auto const& e : l) + hashCombine(h, elemHasher(e)); + return h; +} + +template <typename... Containers> +typename ListZipTypes<Containers...>::Result zip(Containers&&... args) { + typename ListZipTypes<Containers...>::Result res; + for (auto el : zipIterator(args...)) + res.push_back(std::move(el)); + + return res; +} + +template <typename Container> +typename ListEnumerateTypes<Container>::Result enumerate(Container&& container) { + typename ListEnumerateTypes<Container>::Result res; + for (auto el : enumerateIterator(container)) + res.push_back(std::move(el)); + + return res; +} + +} + +#endif diff --git a/source/core/StarListener.cpp b/source/core/StarListener.cpp new file mode 100644 index 0000000..642f2a2 --- /dev/null +++ b/source/core/StarListener.cpp @@ -0,0 +1,49 @@ +#include "StarListener.hpp" + +namespace Star { + +Listener::~Listener() {} + +CallbackListener::CallbackListener(function<void()> callback) + : callback(move(callback)) {} + +void CallbackListener::trigger() { + if (callback) + callback(); +} + +TrackerListener::TrackerListener() : triggered(false) {} + +void ListenerGroup::addListener(ListenerWeakPtr listener) { + MutexLocker locker(m_mutex); + m_listeners.insert(move(listener)); +} + +void ListenerGroup::removeListener(ListenerWeakPtr listener) { + MutexLocker locker(m_mutex); + m_listeners.erase(move(listener)); +} + +void ListenerGroup::clearExpiredListeners() { + MutexLocker locker(m_mutex); + eraseWhere(m_listeners, mem_fn(&ListenerWeakPtr::expired)); +}; + +void ListenerGroup::clearAllListeners() { + MutexLocker locker(m_mutex); + m_listeners.clear(); +} + +void ListenerGroup::trigger() { + MutexLocker locker(m_mutex); + filter(m_listeners, [](ListenerWeakPtr const& wl) { + if (auto lock = wl.lock()) { + lock->trigger(); + return true; + } else { + return false; + } + }); +} + +} diff --git a/source/core/StarListener.hpp b/source/core/StarListener.hpp new file mode 100644 index 0000000..51a8237 --- /dev/null +++ b/source/core/StarListener.hpp @@ -0,0 +1,67 @@ +#ifndef STAR_LISTENER_HPP +#define STAR_LISTENER_HPP + +#include "StarThread.hpp" + +namespace Star { + +STAR_CLASS(Listener); +STAR_CLASS(CallbackListener); +STAR_CLASS(TrackerListener); +STAR_CLASS(ListenerGroup); + +class Listener { +public: + virtual ~Listener(); + virtual void trigger() = 0; +}; + +class CallbackListener : public Listener { +public: + CallbackListener(function<void()> callback); + +protected: + virtual void trigger() override; + +private: + function<void()> callback; +}; + +class TrackerListener : public Listener { +public: + TrackerListener(); + + bool pullTriggered(); + +protected: + virtual void trigger() override; + +private: + atomic<bool> triggered; +}; + +class ListenerGroup { +public: + void addListener(ListenerWeakPtr listener); + void removeListener(ListenerWeakPtr listener); + void clearExpiredListeners(); + void clearAllListeners(); + + void trigger(); + +private: + Mutex m_mutex; + std::set<ListenerWeakPtr, std::owner_less<ListenerWeakPtr>> m_listeners; +}; + +inline bool TrackerListener::pullTriggered() { + return triggered.exchange(false); +} + +inline void TrackerListener::trigger() { + triggered = true; +} + +} + +#endif diff --git a/source/core/StarLockFile.hpp b/source/core/StarLockFile.hpp new file mode 100644 index 0000000..af89c7f --- /dev/null +++ b/source/core/StarLockFile.hpp @@ -0,0 +1,42 @@ +#ifndef STAR_LOCK_FILE_HPP +#define STAR_LOCK_FILE_HPP + +#include "StarMaybe.hpp" +#include "StarString.hpp" + +namespace Star { + +class LockFile { +public: + // Convenience function, tries to acquire a lock, and if succesfull returns an + // already locked + // LockFile. + static Maybe<LockFile> acquireLock(String const& filename, int64_t lockTimeout = 1000); + + LockFile(String const& filename); + LockFile(LockFile&& lockFile); + // Automatically unlocks. + ~LockFile(); + + LockFile(LockFile const&) = delete; + LockFile& operator=(LockFile const&) = delete; + + LockFile& operator=(LockFile&& lockFile); + + // Wait at most timeout time to acquire the file lock, and return true if the + // lock was acquired. If timeout is negative, wait forever. + bool lock(int64_t timeout = 0); + void unlock(); + + bool isLocked() const; + +private: + static int64_t const MaximumSleepMillis = 25; + + String m_filename; + shared_ptr<void> m_handle; +}; + +} + +#endif diff --git a/source/core/StarLockFile_unix.cpp b/source/core/StarLockFile_unix.cpp new file mode 100644 index 0000000..2b6e0bf --- /dev/null +++ b/source/core/StarLockFile_unix.cpp @@ -0,0 +1,94 @@ +#include "StarLockFile.hpp" +#include "StarTime.hpp" +#include "StarThread.hpp" + +#include <sys/file.h> +#include <fcntl.h> +#include <errno.h> +#include <unistd.h> + +namespace Star { + +int64_t const LockFile::MaximumSleepMillis; + +Maybe<LockFile> LockFile::acquireLock(String const& filename, int64_t lockTimeout) { + LockFile lock(move(filename)); + if (lock.lock(lockTimeout)) + return move(lock); + return {}; +} + +LockFile::LockFile(String const& filename) : m_filename(move(filename)) {} + +LockFile::LockFile(LockFile&& lockFile) { + operator=(move(lockFile)); +} + +LockFile::~LockFile() { + unlock(); +} + +LockFile& LockFile::operator=(LockFile&& lockFile) { + m_filename = move(lockFile.m_filename); + m_handle = move(lockFile.m_handle); + + return *this; +} + +bool LockFile::lock(int64_t timeout) { + auto doFLock = [](String const& filename, bool block) -> shared_ptr<int> { + int fd = open(filename.utf8Ptr(), O_RDONLY | O_CREAT, 0644); + if (fd < 0) + throw StarException(strf("Could not open lock file %s, %s\n", filename, strerror(errno))); + + int ret; + if (block) + ret = flock(fd, LOCK_EX); + else + ret = flock(fd, LOCK_EX | LOCK_NB); + + if (ret != 0) { + close(fd); + if (errno != EWOULDBLOCK) + throw StarException(strf("Could not lock file %s, %s\n", filename, strerror(errno))); + return {}; + } + + return make_shared<int>(fd); + }; + + if (timeout < 0) { + m_handle = doFLock(m_filename, true); + return true; + } else if (timeout == 0) { + m_handle = doFLock(m_filename, false); + return (bool)m_handle; + } else { + int64_t startTime = Time::monotonicMilliseconds(); + while (true) { + m_handle = doFLock(m_filename, false); + if (m_handle) + return true; + + if (Time::monotonicMilliseconds() - startTime > timeout) + return false; + + Thread::sleep(min(timeout / 4, MaximumSleepMillis)); + } + } +} + +void LockFile::unlock() { + if (m_handle) { + int fd = *(int*)m_handle.get(); + unlink(m_filename.utf8Ptr()); + close(fd); + m_handle.reset(); + } +} + +bool LockFile::isLocked() const { + return (bool)m_handle; +} + +} diff --git a/source/core/StarLockFile_windows.cpp b/source/core/StarLockFile_windows.cpp new file mode 100644 index 0000000..1cf1473 --- /dev/null +++ b/source/core/StarLockFile_windows.cpp @@ -0,0 +1,80 @@ +#include "StarLockFile.hpp" +#include "StarTime.hpp" +#include "StarThread.hpp" + +#include "StarString_windows.hpp" + +#include <windows.h> + +namespace Star { + +int64_t const LockFile::MaximumSleepMillis; + +Maybe<LockFile> LockFile::acquireLock(String const& filename, int64_t lockTimeout) { + LockFile lock(move(filename)); + if (lock.lock(lockTimeout)) + return move(lock); + return {}; +} + +LockFile::LockFile(String const& filename) : m_filename(move(filename)) {} + +LockFile::LockFile(LockFile&& lockFile) { + operator=(move(lockFile)); +} + +LockFile::~LockFile() { + unlock(); +} + +LockFile& LockFile::operator=(LockFile&& lockFile) { + m_filename = move(lockFile.m_filename); + m_handle = move(lockFile.m_handle); + + return *this; +} + +bool LockFile::lock(int64_t timeout) { + auto doFLock = [](String const& filename) -> shared_ptr<HANDLE> { + HANDLE handle = CreateFileW( + stringToUtf16(filename).get(), GENERIC_READ, 0, nullptr, OPEN_ALWAYS, FILE_FLAG_DELETE_ON_CLOSE, nullptr); + if (handle == INVALID_HANDLE_VALUE) { + if (GetLastError() == ERROR_SHARING_VIOLATION) + return {}; + throw StarException(strf("Could not open lock file %s, error code %s\n", filename, GetLastError())); + } + + return make_shared<HANDLE>(handle); + }; + + if (timeout == 0) { + m_handle = doFLock(m_filename); + return (bool)m_handle; + } else { + int64_t startTime = Time::monotonicMilliseconds(); + while (true) { + m_handle = doFLock(m_filename); + if (m_handle) + return true; + + if (timeout > 0 && Time::monotonicMilliseconds() - startTime > timeout) + return false; + + Thread::sleep(min(timeout / 4, MaximumSleepMillis)); + } + } +} + +void LockFile::unlock() { + if (m_handle) { + HANDLE handle = *(HANDLE*)m_handle.get(); + CloseHandle(handle); + m_handle.reset(); + } +} + +bool LockFile::isLocked() const { + return (bool)m_handle; +} + +} diff --git a/source/core/StarLogging.cpp b/source/core/StarLogging.cpp new file mode 100644 index 0000000..2a7f88e --- /dev/null +++ b/source/core/StarLogging.cpp @@ -0,0 +1,192 @@ +#include "StarLogging.hpp" + +namespace Star { + +EnumMap<LogLevel> const LogLevelNames{ + {LogLevel::Debug, "Debug"}, + {LogLevel::Info, "Info"}, + {LogLevel::Warn, "Warn"}, + {LogLevel::Error, "Error"} +}; + +LogSink::LogSink() + : m_level(LogLevel::Info) {} + +LogSink::~LogSink() {} + +void LogSink::setLevel(LogLevel level) { + m_level = level; +} + +LogLevel LogSink::level() { + return m_level; +} + +void StdoutLogSink::log(char const* msg, LogLevel level) { + MutexLocker locker(m_logMutex); + coutf("[%s] %s\n", LogLevelNames.getRight(level), msg); +} + +FileLogSink::FileLogSink(String const& filename, LogLevel level, bool truncate) { + if (truncate) + m_output = File::open(filename, IOMode::Write | IOMode::Append | IOMode::Truncate); + else + m_output = File::open(filename, IOMode::Write | IOMode::Append); + setLevel(level); +} + +void FileLogSink::log(char const* msg, LogLevel level) { + MutexLocker locker(m_logMutex); + auto line = strf("[%s] [%s] %s\n", Time::printCurrentDateAndTime("<hours>:<minutes>:<seconds>.<millis>"), LogLevelNames.getRight(level), msg); + m_output->write(line.data(), line.size()); +} + +void Logger::addSink(LogSinkPtr s) { + MutexLocker locker(s_mutex); + s_sinks.insert(s); +} + +void Logger::removeSink(LogSinkPtr s) { + MutexLocker locker(s_mutex); + s_sinks.erase(s); +} + +LogSinkPtr Logger::stdoutSink() { + MutexLocker locker(s_mutex); + return s_stdoutSink; +} + +void Logger::removeStdoutSink() { + MutexLocker locker(s_mutex); + s_sinks.erase(s_stdoutSink); +} + +void Logger::log(LogLevel level, char const* msg) { + MutexLocker locker(s_mutex); + + for (auto const& l : s_sinks) { + if (l->level() <= level) + l->log(msg, level); + } +} + +shared_ptr<StdoutLogSink> Logger::s_stdoutSink = make_shared<StdoutLogSink>(); +HashSet<LogSinkPtr> Logger::s_sinks{s_stdoutSink}; +Mutex Logger::s_mutex; + +String LogMap::getValue(String const& key) { + MutexLocker locker(s_logMapMutex); + return s_logMap.value(key); +} + +void LogMap::setValue(String const& key, String const& value) { + MutexLocker locker(s_logMapMutex); + s_logMap[key] = value; +} + +Map<String, String> LogMap::getValues() { + MutexLocker locker(s_logMapMutex); + return Map<String, String>::from(s_logMap); +} + +void LogMap::clear() { + MutexLocker locker(s_logMapMutex); + s_logMap.clear(); +} + +HashMap<String, String> LogMap::s_logMap; +Mutex LogMap::s_logMapMutex; + +size_t const SpatialLogger::MaximumLines; +size_t const SpatialLogger::MaximumPoints; +size_t const SpatialLogger::MaximumText; + +void SpatialLogger::logPoly(char const* space, PolyF const& poly, Vec4B const& color) { + MutexLocker locker(s_mutex); + auto& lines = s_lines[space]; + + for (size_t i = 0; i < poly.sides(); ++i) { + auto side = poly.side(i); + lines.append(Line{side.min(), side.max(), color}); + } + + while (lines.size() > MaximumLines) + lines.removeFirst(); +} + +void SpatialLogger::logLine(char const* space, Line2F const& line, Vec4B const& color) { + MutexLocker locker(s_mutex); + auto& lines = s_lines[space]; + + lines.append(Line{line.min(), line.max(), color}); + + while (lines.size() > MaximumLines) + lines.removeFirst(); +} + +void SpatialLogger::logLine(char const* space, Vec2F const& begin, Vec2F const& end, Vec4B const& color) { + MutexLocker locker(s_mutex); + auto& lines = s_lines[space]; + + lines.append(Line{begin, end, color}); + + while (lines.size() > MaximumLines) + lines.removeFirst(); +} + +void SpatialLogger::logPoint(char const* space, Vec2F const& position, Vec4B const& color) { + MutexLocker locker(s_mutex); + auto& points = s_points[space]; + + points.append(Point{position, color}); + + while (points.size() > MaximumPoints) + points.removeFirst(); +} + +void SpatialLogger::logText(char const* space, String text, Vec2F const& position, Vec4B const& color) { + MutexLocker locker(s_mutex); + auto& texts = s_logText[space]; + + texts.append(LogText{text, position, color}); + + while (texts.size() > MaximumText) + texts.removeFirst(); +} + +Deque<SpatialLogger::Line> SpatialLogger::getLines(char const* space, bool andClear) { + MutexLocker locker(s_mutex); + if (andClear) + return take(s_lines[space]); + else + return s_lines[space]; +} + +Deque<SpatialLogger::Point> SpatialLogger::getPoints(char const* space, bool andClear) { + MutexLocker locker(s_mutex); + if (andClear) + return take(s_points[space]); + else + return s_points[space]; +} + +Deque<SpatialLogger::LogText> SpatialLogger::getText(char const* space, bool andClear) { + MutexLocker locker(s_mutex); + if (andClear) + return take(s_logText[space]); + else + return s_logText[space]; +} + +void SpatialLogger::clear() { + MutexLocker locker(s_mutex); + s_lines.clear(); + s_points.clear(); + s_logText.clear(); +} + +Mutex SpatialLogger::s_mutex; +StringMap<Deque<SpatialLogger::Line>> SpatialLogger::s_lines; +StringMap<Deque<SpatialLogger::Point>> SpatialLogger::s_points; +StringMap<Deque<SpatialLogger::LogText>> SpatialLogger::s_logText; +} diff --git a/source/core/StarLogging.hpp b/source/core/StarLogging.hpp new file mode 100644 index 0000000..d90e0e1 --- /dev/null +++ b/source/core/StarLogging.hpp @@ -0,0 +1,193 @@ +#ifndef STAR_LOGGING_HPP +#define STAR_LOGGING_HPP + +#include "StarThread.hpp" +#include "StarSet.hpp" +#include "StarString.hpp" +#include "StarPoly.hpp" +#include "StarBiMap.hpp" +#include "StarTime.hpp" +#include "StarFile.hpp" + +namespace Star { + +enum class LogLevel { + Debug, + Info, + Warn, + Error +}; +extern EnumMap<LogLevel> const LogLevelNames; + +STAR_CLASS(LogSink); + +// A sink for Logger messages. +class LogSink { +public: + LogSink(); + virtual ~LogSink(); + + virtual void log(char const* msg, LogLevel level) = 0; + + void setLevel(LogLevel level); + LogLevel level(); + +private: + atomic<LogLevel> m_level; +}; + +class StdoutLogSink : public LogSink { +public: + virtual void log(char const* msg, LogLevel level); + +private: + Mutex m_logMutex; +}; + +class FileLogSink : public LogSink { +public: + FileLogSink(String const& filename, LogLevel level, bool truncate); + + virtual void log(char const* msg, LogLevel level); + +private: + FilePtr m_output; + Mutex m_logMutex; +}; + +// A basic loging system that logs to multiple streams. Can log at Debug, +// Info, Warn, and Error logging levels. By default logs to stdout. +class Logger { +public: + static void addSink(LogSinkPtr s); + static void removeSink(LogSinkPtr s); + + // Default LogSink that outputs to stdout. + static LogSinkPtr stdoutSink(); + // Don't use the stdout sink. + static void removeStdoutSink(); + + static void log(LogLevel level, char const* msg); + + template <typename... Args> + static void logf(LogLevel level, char const* msg, Args const&... args); + + template <typename... Args> + static void debug(char const* msg, Args const&... args); + template <typename... Args> + static void info(char const* msg, Args const&... args); + template <typename... Args> + static void warn(char const* msg, Args const&... args); + template <typename... Args> + static void error(char const* msg, Args const&... args); + +private: + static shared_ptr<StdoutLogSink> s_stdoutSink; + static HashSet<LogSinkPtr> s_sinks; + static Mutex s_mutex; +}; + +// For logging data that is very high frequency. It is a map of debug values to +// be displayed every frame, or in a debug output window, etc. +class LogMap { +public: + static String getValue(String const& key); + static void setValue(String const& key, String const& value); + + // Shorthand, converts given type to string using std::ostream. + template <typename T> + static void set(String const& key, T const& t); + + static Map<String, String> getValues(); + static void clear(); + +private: + static HashMap<String, String> s_logMap; + static Mutex s_logMapMutex; +}; + +// Logging for spatial data. Divided into multiple named coordinate spaces. +class SpatialLogger { +public: + // Maximum count of objects stored per space + static size_t const MaximumLines = 200000; + static size_t const MaximumPoints = 200000; + static size_t const MaximumText = 10000; + + struct Line { + Vec2F begin; + Vec2F end; + Vec4B color; + }; + + struct Point { + Vec2F position; + Vec4B color; + }; + + struct LogText { + String text; + Vec2F position; + Vec4B color; + }; + + static void logPoly(char const* space, PolyF const& poly, Vec4B const& color); + static void logLine(char const* space, Line2F const& line, Vec4B const& color); + static void logLine(char const* space, Vec2F const& begin, Vec2F const& end, Vec4B const& color); + static void logPoint(char const* space, Vec2F const& position, Vec4B const& color); + static void logText(char const* space, String text, Vec2F const& position, Vec4B const& color); + + static Deque<Line> getLines(char const* space, bool andClear); + static Deque<Point> getPoints(char const* space, bool andClear); + static Deque<LogText> getText(char const* space, bool andClear); + + static void clear(); + +private: + static Mutex s_mutex; + static StringMap<Deque<Line>> s_lines; + static StringMap<Deque<Point>> s_points; + static StringMap<Deque<LogText>> s_logText; +}; + +template <typename... Args> +void Logger::logf(LogLevel level, char const* msg, Args const&... args) { + MutexLocker locker(s_mutex); + Maybe<std::string> output; + for (auto const& l : s_sinks) { + if (l->level() <= level) { + if (!output) + output = strf(msg, args...); + l->log(output->c_str(), level); + } + } +} + +template <typename... Args> +void Logger::debug(char const* msg, Args const&... args) { + logf(LogLevel::Debug, msg, args...); +} + +template <typename... Args> +void Logger::info(char const* msg, Args const&... args) { + logf(LogLevel::Info, msg, args...); +} + +template <typename... Args> +void Logger::warn(char const* msg, Args const&... args) { + logf(LogLevel::Warn, msg, args...); +} + +template <typename... Args> +void Logger::error(char const* msg, Args const&... args) { + logf(LogLevel::Error, msg, args...); +} + +template <typename T> +void LogMap::set(String const& key, T const& t) { + setValue(key, strf("%s", t)); +} + +} + +#endif diff --git a/source/core/StarLruCache.hpp b/source/core/StarLruCache.hpp new file mode 100644 index 0000000..d047a37 --- /dev/null +++ b/source/core/StarLruCache.hpp @@ -0,0 +1,149 @@ +#ifndef STAR_LRU_CACHE_HPP +#define STAR_LRU_CACHE_HPP + +#include "StarOrderedMap.hpp" +#include "StarBlockAllocator.hpp" + +namespace Star { + +template <typename OrderedMapType> +class LruCacheBase { +public: + typedef typename OrderedMapType::key_type Key; + typedef typename OrderedMapType::mapped_type Value; + + typedef function<Value(Key const&)> ProducerFunction; + + LruCacheBase(size_t maxSize = 256); + + // Max size cannot be zero, it will be clamped to at least 1 in order to hold + // the most recent element returned by get. + size_t maxSize() const; + void setMaxSize(size_t maxSize); + + size_t currentSize() const; + + List<Key> keys() const; + List<Value> values() const; + + // If the value is in the cache, returns a pointer to it and marks it as + // accessed, otherwise returns nullptr. + Value* ptr(Key const& key); + + // Put the given value into the cache. + void set(Key const& key, Value value); + // Removes the given value from the cache. If found and removed, returns + // true. + bool remove(Key const& key); + + // Remove all key / value pairs matching a filter. + void removeWhere(function<bool(Key const&, Value&)> filter); + + // If the value for the key is not found in the cache, produce it with the + // given producer. Producer shold take the key as an argument and return the + // value. + template <typename Producer> + Value& get(Key const& key, Producer producer); + + // Clear all cached entries. + void clear(); + +private: + OrderedMapType m_map; + size_t m_maxSize; +}; + +template <typename Key, typename Value, typename Compare = std::less<Key>, typename Allocator = BlockAllocator<pair<Key const, Value>, 1024>> +using LruCache = LruCacheBase<OrderedMap<Key, Value, Compare, Allocator>>; + +template <typename Key, typename Value, typename Hash = Star::hash<Key>, typename Equals = std::equal_to<Key>, typename Allocator = BlockAllocator<pair<Key const, Value>, 1024>> +using HashLruCache = LruCacheBase<OrderedHashMap<Key, Value, Hash, Equals, Allocator>>; + +template <typename OrderedMapType> +LruCacheBase<OrderedMapType>::LruCacheBase(size_t maxSize) { + setMaxSize(maxSize); +} + +template <typename OrderedMapType> +size_t LruCacheBase<OrderedMapType>::maxSize() const { + return m_maxSize; +} + +template <typename OrderedMapType> +void LruCacheBase<OrderedMapType>::setMaxSize(size_t maxSize) { + m_maxSize = max<size_t>(maxSize, 1); + + while (m_map.size() > m_maxSize) + m_map.removeFirst(); +} + +template <typename OrderedMapType> +size_t LruCacheBase<OrderedMapType>::currentSize() const { + return m_map.size(); +} + +template <typename OrderedMapType> +auto LruCacheBase<OrderedMapType>::keys() const -> List<Key> { + return m_map.keys(); +} + +template <typename OrderedMapType> +auto LruCacheBase<OrderedMapType>::values() const -> List<Value> { + return m_map.values(); +} + +template <typename OrderedMapType> +auto LruCacheBase<OrderedMapType>::ptr(Key const& key) -> Value * { + auto i = m_map.find(key); + if (i == m_map.end()) + return nullptr; + i = m_map.toBack(i); + return &i->second; +} + +template <typename OrderedMapType> +void LruCacheBase<OrderedMapType>::set(Key const& key, Value value) { + auto i = m_map.find(key); + if (i == m_map.end()) { + m_map.add(key, move(value)); + } else { + i->second = move(value); + m_map.toBack(i); + } +} + +template <typename OrderedMapType> +bool LruCacheBase<OrderedMapType>::remove(Key const& key) { + return m_map.remove(key); +} + +template <typename OrderedMapType> +void LruCacheBase<OrderedMapType>::removeWhere(function<bool(Key const&, Value&)> filter) { + eraseWhere(m_map, [&filter](auto& p) { + return filter(p.first, p.second); + }); +} + +template <typename OrderedMapType> +template <typename Producer> +auto LruCacheBase<OrderedMapType>::get(Key const& key, Producer producer) -> Value & { + while (m_map.size() > m_maxSize - 1) + m_map.removeFirst(); + + auto i = m_map.find(key); + if (i == m_map.end()) + i = m_map.insert({key, producer(key)}).first; + else + i = m_map.toBack(i); + + return i->second; +} + +template <typename OrderedMapType> +void LruCacheBase<OrderedMapType>::clear() { + m_map.clear(); +} + +} + +#endif diff --git a/source/core/StarLua.cpp b/source/core/StarLua.cpp new file mode 100644 index 0000000..b297820 --- /dev/null +++ b/source/core/StarLua.cpp @@ -0,0 +1,1459 @@ +#include "StarLua.hpp" +#include "StarArray.hpp" +#include "StarTime.hpp" + +namespace Star { + +std::ostream& operator<<(std::ostream& os, LuaValue const& value) { + if (value.is<LuaBoolean>()) { + os << (value.get<LuaBoolean>() ? "true" : "false"); + } else if (value.is<LuaInt>()) { + os << value.get<LuaInt>(); + } else if (value.is<LuaFloat>()) { + os << value.get<LuaFloat>(); + } else if (value.is<LuaString>()) { + os << value.get<LuaString>().ptr(); + } else if (value.is<LuaTable>()) { + os << "{"; + bool first = true; + value.get<LuaTable>().iterate([&os, &first](LuaValue const& key, LuaValue const& value) { + if (first) + first = false; + else + os << ", "; + os << key << ": " << value; + }); + os << "}"; + } else if (value.is<LuaFunction>()) { + os << "<function reg:" << value.get<LuaFunction>().handleIndex() << ">"; + } else if (value.is<LuaThread>()) { + os << "<thread reg:" << value.get<LuaThread>().handleIndex() << ">"; + } else if (value.is<LuaUserData>()) { + os << "<userdata reg:" << value.get<LuaUserData>().handleIndex() << ">"; + } else { + os << "nil"; + } + return os; +} + +bool LuaTable::contains(char const* key) const { + return engine().tableGet(false, handleIndex(), key) != LuaNil; +} + +void LuaTable::remove(char const* key) const { + engine().tableSet(false, handleIndex(), key, LuaNil); +} + +LuaInt LuaTable::length() const { + return engine().tableLength(false, handleIndex()); +} + +Maybe<LuaTable> LuaTable::getMetatable() const { + return engine().tableGetMetatable(handleIndex()); +} + +void LuaTable::setMetatable(LuaTable const& table) const { + return engine().tableSetMetatable(handleIndex(), table); +} + +LuaInt LuaTable::rawLength() const { + return engine().tableLength(true, handleIndex()); +} + +LuaCallbacks& LuaCallbacks::merge(LuaCallbacks const& callbacks) { + try { + for (auto const& pair : callbacks.m_callbacks) + m_callbacks.add(pair.first, pair.second); + } catch (MapException const& e) { + throw LuaException(strf("Failed to merge LuaCallbacks: %s", outputException(e, true))); + } + + return *this; +} + +StringMap<LuaDetail::LuaWrappedFunction> const& LuaCallbacks::callbacks() const { + return m_callbacks; +} + +bool LuaContext::containsPath(String path) const { + return engine().contextGetPath(handleIndex(), move(path)) != LuaNil; +} + +void LuaContext::load(char const* contents, size_t size, char const* name) { + engine().contextLoad(handleIndex(), contents, size, name); +} + +void LuaContext::load(String const& contents, String const& name) { + load(contents.utf8Ptr(), contents.utf8Size(), name.utf8Ptr()); +} + +void LuaContext::load(ByteArray const& contents, String const& name) { + load(contents.ptr(), contents.size(), name.utf8Ptr()); +} + +void LuaContext::setRequireFunction(RequireFunction requireFunction) { + engine().setContextRequire(handleIndex(), move(requireFunction)); +} + +void LuaContext::setCallbacks(String const& tableName, LuaCallbacks const& callbacks) const { + auto& eng = engine(); + auto callbackTable = eng.createTable(); + for (auto const& p : callbacks.callbacks()) + callbackTable.set(p.first, eng.createWrappedFunction(p.second)); + LuaContext::set(tableName, callbackTable); +} + +LuaString LuaContext::createString(String const& str) { + return engine().createString(str); +} + +LuaString LuaContext::createString(char const* str) { + return engine().createString(str); +} + +LuaTable LuaContext::createTable() { + return engine().createTable(); +} + +LuaValue LuaConverter<Json>::from(LuaEngine& engine, Json const& v) { + if (v.isType(Json::Type::Null)) { + return LuaNil; + } else if (v.isType(Json::Type::Float)) { + return LuaFloat(v.toDouble()); + } else if (v.isType(Json::Type::Bool)) { + return v.toBool(); + } else if (v.isType(Json::Type::Int)) { + return LuaInt(v.toInt()); + } else if (v.isType(Json::Type::String)) { + return engine.createString(v.stringPtr()->utf8Ptr()); + } else { + return LuaDetail::jsonContainerToTable(engine, v); + } +} + +Maybe<Json> LuaConverter<Json>::to(LuaEngine&, LuaValue const& v) { + if (v == LuaNil) + return Json(); + + if (auto b = v.ptr<LuaBoolean>()) + return Json(*b); + + if (auto i = v.ptr<LuaInt>()) + return Json(*i); + + if (auto f = v.ptr<LuaFloat>()) + return Json(*f); + + if (auto s = v.ptr<LuaString>()) + return Json(s->ptr()); + + if (v.is<LuaTable>()) + return LuaDetail::tableToJsonContainer(v.get<LuaTable>()); + + return {}; +} + +LuaValue LuaConverter<JsonObject>::from(LuaEngine& engine, JsonObject v) { + return engine.luaFrom<Json>(Json(move(v))); +} + +Maybe<JsonObject> LuaConverter<JsonObject>::to(LuaEngine& engine, LuaValue v) { + auto j = engine.luaTo<Json>(move(v)); + if (j.type() == Json::Type::Object) { + return j.toObject(); + } else if (j.type() == Json::Type::Array) { + auto list = j.arrayPtr(); + if (list->empty()) + return JsonObject(); + } + + return {}; +} + +LuaValue LuaConverter<JsonArray>::from(LuaEngine& engine, JsonArray v) { + return engine.luaFrom<Json>(Json(move(v))); +} + +Maybe<JsonArray> LuaConverter<JsonArray>::to(LuaEngine& engine, LuaValue v) { + auto j = engine.luaTo<Json>(move(v)); + if (j.type() == Json::Type::Array) { + return j.toArray(); + } else if (j.type() == Json::Type::Object) { + auto map = j.objectPtr(); + if (map->empty()) + return JsonArray(); + } + + return {}; +} + +LuaEnginePtr LuaEngine::create(bool safe) { + LuaEnginePtr self(new LuaEngine); + + self->m_state = lua_newstate(allocate, nullptr); + + self->m_scriptDefaultEnvRegistryId = LUA_NOREF; + self->m_wrappedFunctionMetatableRegistryId = LUA_NOREF; + self->m_requireFunctionMetatableRegistryId = LUA_NOREF; + + self->m_instructionLimit = 0; + self->m_profilingEnabled = false; + self->m_instructionMeasureInterval = 1000; + self->m_instructionCount = 0; + self->m_recursionLevel = 0; + self->m_recursionLimit = 0; + + if (!self->m_state) + throw LuaException("Failed to initialize Lua"); + + lua_checkstack(self->m_state, 5); + + // Create handle stack thread and place it in the registry to prevent it from being garbage + // collected. + + self->m_handleThread = lua_newthread(self->m_state); + luaL_ref(self->m_state, LUA_REGISTRYINDEX); + + // We need 1 extra stack space to move values in and out of the handle stack. + self->m_handleStackSize = LUA_MINSTACK - 1; + self->m_handleStackMax = 0; + + // Set the extra space in the lua main state to the pointer to the main + // LuaEngine + *reinterpret_cast<LuaEngine**>(lua_getextraspace(self->m_state)) = self.get(); + + // Create the common message handler function for pcall to print a better + // message with a traceback + lua_pushcfunction(self->m_state, [](lua_State* state) { + // Don't modify the error if it is one of the special limit errrors + if (lua_islightuserdata(state, 1)) { + void* error = lua_touserdata(state, -1); + if (error == &s_luaInstructionLimitExceptionKey || error == &s_luaRecursionLimitExceptionKey) + return 1; + } + + luaL_traceback(state, state, lua_tostring(state, 1), 0); + lua_remove(state, 1); + return 1; + }); + self->m_pcallTracebackMessageHandlerRegistryId = luaL_ref(self->m_state, LUA_REGISTRYINDEX); + + // Create the common metatable for wrapped functions + lua_newtable(self->m_state); + lua_pushcfunction(self->m_state, [](lua_State* state) { + auto func = (LuaDetail::LuaWrappedFunction*)lua_touserdata(state, 1); + func->~function(); + return 0; + }); + LuaDetail::rawSetField(self->m_state, -2, "__gc"); + lua_pushboolean(self->m_state, 0); + LuaDetail::rawSetField(self->m_state, -2, "__metatable"); + self->m_wrappedFunctionMetatableRegistryId = luaL_ref(self->m_state, LUA_REGISTRYINDEX); + + // Create the common metatable for require functions + lua_newtable(self->m_state); + lua_pushcfunction(self->m_state, [](lua_State* state) { + auto func = (LuaContext::RequireFunction*)lua_touserdata(state, 1); + func->~function(); + return 0; + }); + LuaDetail::rawSetField(self->m_state, -2, "__gc"); + lua_pushboolean(self->m_state, 0); + LuaDetail::rawSetField(self->m_state, -2, "__metatable"); + self->m_requireFunctionMetatableRegistryId = luaL_ref(self->m_state, LUA_REGISTRYINDEX); + + // Load all base libraries and prune them of unsafe functions + + luaL_requiref(self->m_state, "_ENV", luaopen_base, true); + if (safe) { + StringSet baseWhitelist = { + "assert", + "error", + "getmetatable", + "ipairs", + "next", + "pairs", + "pcall", + "print", + "rawequal", + "rawget", + "rawlen", + "rawset", + "select", + "setmetatable", + "tonumber", + "tostring", + "type", + "unpack", + "_VERSION", + "xpcall"}; + + lua_pushnil(self->m_state); + while (lua_next(self->m_state, -2) != 0) { + lua_pop(self->m_state, 1); + String key(lua_tostring(self->m_state, -1)); + + if (!baseWhitelist.contains(key)) { + lua_pushvalue(self->m_state, -1); + lua_pushnil(self->m_state); + lua_rawset(self->m_state, -4); + } + } + } + lua_pop(self->m_state, 1); + + luaL_requiref(self->m_state, "os", luaopen_os, true); + if (safe) { + StringSet osWhitelist = {"clock", "difftime", "time"}; + + lua_pushnil(self->m_state); + while (lua_next(self->m_state, -2) != 0) { + lua_pop(self->m_state, 1); + String key(lua_tostring(self->m_state, -1)); + + if (!osWhitelist.contains(key)) { + lua_pushvalue(self->m_state, -1); + lua_pushnil(self->m_state); + lua_rawset(self->m_state, -4); + } + } + } + lua_pop(self->m_state, 1); + + // loads a lua base library, leaves it at the top of the stack + auto loadBaseLibrary = [](lua_State* state, char const* modname, lua_CFunction openf) { + luaL_requiref(state, modname, openf, true); + + // set __metatable metamethod to false + // otherwise scripts can access and mutate the metatable, allowing passing values + // between script contexts, breaking the sandbox + lua_newtable(state); + lua_pushliteral(state, "__metatable"); + lua_pushboolean(state, 0); + lua_rawset(state, -3); + lua_setmetatable(state, -2); + }; + + loadBaseLibrary(self->m_state, "coroutine", luaopen_coroutine); + // replace coroutine resume with one that appends tracebacks + lua_pushliteral(self->m_state, "resume"); + lua_pushcfunction(self->m_state, &LuaEngine::coresumeWithTraceback); + lua_rawset(self->m_state, -3); + + loadBaseLibrary(self->m_state, "math", luaopen_math); + loadBaseLibrary(self->m_state, "string", luaopen_string); + loadBaseLibrary(self->m_state, "table", luaopen_table); + loadBaseLibrary(self->m_state, "utf8", luaopen_utf8); + lua_pop(self->m_state, 5); + + if (!safe) { + loadBaseLibrary(self->m_state, "io", luaopen_io); + loadBaseLibrary(self->m_state, "package", luaopen_package); + loadBaseLibrary(self->m_state, "debug", luaopen_debug); + lua_pop(self->m_state, 3); + } + + // Make a shallow copy of the default script environment and save it for + // resetting the global state. + lua_rawgeti(self->m_state, LUA_REGISTRYINDEX, LUA_RIDX_GLOBALS); + lua_newtable(self->m_state); + LuaDetail::shallowCopy(self->m_state, -2, -1); + self->m_scriptDefaultEnvRegistryId = luaL_ref(self->m_state, LUA_REGISTRYINDEX); + lua_pop(self->m_state, 1); + + self->setGlobal("jarray", self->createFunction(&LuaDetail::jarrayCreate)); + self->setGlobal("jobject", self->createFunction(&LuaDetail::jobjectCreate)); + self->setGlobal("jremove", self->createFunction(&LuaDetail::jcontRemove)); + self->setGlobal("jsize", self->createFunction(&LuaDetail::jcontSize)); + self->setGlobal("jresize", self->createFunction(&LuaDetail::jcontResize)); + return self; +} + +LuaEngine::~LuaEngine() { + // If we've had a stack space leak, this will not be zero + starAssert(lua_gettop(m_state) == 0); + lua_close(m_state); +} + +void LuaEngine::setInstructionLimit(uint64_t instructionLimit) { + if (instructionLimit != m_instructionLimit) { + m_instructionLimit = instructionLimit; + updateCountHook(); + } +} + +uint64_t LuaEngine::instructionLimit() const { + return m_instructionLimit; +} + +void LuaEngine::setProfilingEnabled(bool profilingEnabled) { + if (profilingEnabled != m_profilingEnabled) { + m_profilingEnabled = profilingEnabled; + m_profileEntries.clear(); + updateCountHook(); + } +} + +bool LuaEngine::profilingEnabled() const { + return m_profilingEnabled; +} + +List<LuaProfileEntry> LuaEngine::getProfile() { + List<LuaProfileEntry> profileEntries; + for (auto const& p : m_profileEntries) { + profileEntries.append(*p.second); + } + + return profileEntries; +} + +void LuaEngine::setInstructionMeasureInterval(unsigned measureInterval) { + if (measureInterval != m_instructionMeasureInterval) { + m_instructionMeasureInterval = measureInterval; + updateCountHook(); + } +} + +unsigned LuaEngine::instructionMeasureInterval() const { + return m_instructionMeasureInterval; +} + +void LuaEngine::setRecursionLimit(unsigned recursionLimit) { + m_recursionLimit = recursionLimit; +} + +unsigned LuaEngine::recursionLimit() const { + return m_recursionLimit; +} + +ByteArray LuaEngine::compile(char const* contents, size_t size, char const* name) { + lua_checkstack(m_state, 1); + + handleError(m_state, luaL_loadbuffer(m_state, contents, size, name)); + + ByteArray compiledScript; + lua_Writer writer = [](lua_State*, void const* data, size_t size, void* byteArrayPtr) -> int { + ((ByteArray*)byteArrayPtr)->append((char const*)data, size); + return 0; + }; + lua_dump(m_state, writer, &compiledScript, false); + lua_pop(m_state, 1); + + return compiledScript; +} + +ByteArray LuaEngine::compile(String const& contents, String const& name) { + return compile(contents.utf8Ptr(), contents.utf8Size(), name.empty() ? nullptr : name.utf8Ptr()); +} + +ByteArray LuaEngine::compile(ByteArray const& contents, String const& name) { + return compile(contents.ptr(), contents.size(), name.empty() ? nullptr : name.utf8Ptr()); +} + +LuaString LuaEngine::createString(String const& str) { + return createString(str.utf8Ptr()); +} + +LuaString LuaEngine::createString(char const* str) { + lua_checkstack(m_state, 1); + + lua_pushstring(m_state, str); + return LuaString(LuaDetail::LuaHandle(RefPtr<LuaEngine>(this), popHandle(m_state))); +} + +LuaTable LuaEngine::createTable() { + lua_checkstack(m_state, 1); + + lua_newtable(m_state); + return LuaTable(LuaDetail::LuaHandle(RefPtr<LuaEngine>(this), popHandle(m_state))); +} + +LuaThread LuaEngine::createThread() { + lua_checkstack(m_state, 1); + + lua_newthread(m_state); + return LuaThread(LuaDetail::LuaHandle(RefPtr<LuaEngine>(this), popHandle(m_state))); +} + +void LuaEngine::threadPushFunction(int threadIndex, int functionIndex) { + lua_State* thread = lua_tothread(m_handleThread, threadIndex); + + int status = lua_status(thread); + lua_Debug ar; + if (status != LUA_OK || lua_getstack(thread, 0, &ar) > 0 || lua_gettop(thread) > 0) + throw LuaException(strf("Cannot push function to active or errored thread with status %s", status)); + + pushHandle(thread, functionIndex); +} + +LuaThread::Status LuaEngine::threadStatus(int handleIndex) { + lua_State* thread = lua_tothread(m_handleThread, handleIndex); + + int status = lua_status(thread); + if (status != LUA_OK && status != LUA_YIELD) + return LuaThread::Status::Error; + + lua_Debug ar; + if (status == LUA_YIELD || lua_getstack(thread, 0, &ar) > 0 || lua_gettop(thread) > 0) + return LuaThread::Status::Active; + + return LuaThread::Status::Dead; +} + +LuaContext LuaEngine::createContext() { + lua_checkstack(m_state, 2); + + // Create a new blank environment and copy the default environment to it. + lua_newtable(m_state); + lua_rawgeti(m_state, LUA_REGISTRYINDEX, m_scriptDefaultEnvRegistryId); + LuaDetail::shallowCopy(m_state, -1, -2); + lua_pop(m_state, 1); + + // Then set that environment as the new context environment in the registry. + return LuaContext(LuaDetail::LuaHandle(RefPtr<LuaEngine>(this), popHandle(m_state))); +} + +void LuaEngine::collectGarbage(Maybe<unsigned> steps) { + for (auto handleIndex : take(m_handleFree)) { + lua_pushnil(m_handleThread); + lua_replace(m_handleThread, handleIndex); + } + + if (steps) + lua_gc(m_state, LUA_GCSTEP, *steps); + else + lua_gc(m_state, LUA_GCCOLLECT, 0); +} + +void LuaEngine::setAutoGarbageCollection(bool autoGarbageColleciton) { + lua_gc(m_state, LUA_GCSTOP, autoGarbageColleciton ? 1 : 0); +} + +void LuaEngine::tuneAutoGarbageCollection(float pause, float stepMultiplier) { + lua_gc(m_state, LUA_GCSETPAUSE, round(pause * 100)); + lua_gc(m_state, LUA_GCSETSTEPMUL, round(stepMultiplier * 100)); +} + +size_t LuaEngine::memoryUsage() const { + return (size_t)lua_gc(m_state, LUA_GCCOUNT, 0) * 1024 + lua_gc(m_state, LUA_GCCOUNTB, 0); +} + +LuaEngine* LuaEngine::luaEnginePtr(lua_State* state) { + return (*reinterpret_cast<LuaEngine**>(lua_getextraspace(state))); +} + +void LuaEngine::countHook(lua_State* state, lua_Debug* ar) { + starAssert(ar->event == LUA_HOOKCOUNT); + lua_checkstack(state, 4); + + auto self = luaEnginePtr(state); + + // If the instruction count is 0, that means in this sequence of calls, + // we have not hit a debug hook yet. Since we don't know the state of + // the internal lua instruction counter at the start, we don't know how + // many instructions have been executed, only that it is >= 1 and <= + // m_instructionMeasureInterval, so we pick the low estimate. + if (self->m_instructionCount == 0) + self->m_instructionCount = 1; + else + self->m_instructionCount += self->m_instructionMeasureInterval; + + if (self->m_instructionLimit != 0 && self->m_instructionCount > self->m_instructionLimit) { + lua_pushlightuserdata(state, &s_luaInstructionLimitExceptionKey); + lua_error(state); + } + + if (self->m_profilingEnabled) { + // find bottom of the stack + // ar will contain the stack info from the last call that returns 1 + int stackLevel = -1; + while (lua_getstack(state, stackLevel + 1, ar) == 1) + stackLevel++; + + shared_ptr<LuaProfileEntry> parentEntry = nullptr; + while (true) { + // Get the 'n' name info and 'S' source info + if (lua_getinfo(state, "nS", ar) == 0) + break; + + auto key = make_tuple(String(ar->short_src), (unsigned)ar->linedefined); + auto& entryMap = parentEntry ? parentEntry->calls : self->m_profileEntries; + if (!entryMap.contains(key)) { + auto e = make_shared<LuaProfileEntry>(); + e->source = ar->short_src; + e->sourceLine = ar->linedefined; + entryMap.set(key, e); + } + auto entry = entryMap.get(key); + + // metadata + if (!entry->name && ar->name) + entry->name = String(ar->name); + if (!entry->nameScope && String() != ar->namewhat) + entry->nameScope = String(ar->namewhat); + + // Only add timeTaken to the self time for the function we are actually + // in, not parent functions + if (stackLevel == 0) { + entry->totalTime += 1; + entry->selfTime += 1; + } else { + entry->totalTime += 1; + } + + parentEntry = entry; + if (lua_getstack(state, --stackLevel, ar) == 0) + break; + } + } +} + +void* LuaEngine::allocate(void*, void* ptr, size_t oldSize, size_t newSize) { + if (newSize == 0) { + Star::free(ptr, oldSize); + return nullptr; + } else { + return Star::realloc(ptr, newSize); + } +} + +void LuaEngine::handleError(lua_State* state, int res) { + if (res != LUA_OK) { + if (lua_islightuserdata(state, -1)) { + void* error = lua_touserdata(state, -1); + if (error == &s_luaInstructionLimitExceptionKey) { + lua_pop(state, 1); + throw LuaInstructionLimitReached(); + } + if (error == &s_luaRecursionLimitExceptionKey) { + lua_pop(state, 1); + throw LuaRecursionLimitReached(); + } + } + + String error; + if (lua_isstring(state, -1)) + error = strf("Error code %s, %s", res, lua_tostring(state, -1)); + else + error = strf("Error code %s, <unknown error>", res); + + lua_pop(state, 1); + + // This seems terrible, but as far as I can tell, this is exactly what the + // stock lua repl does. + if (error.endsWith("<eof>")) + throw LuaIncompleteStatementException(error.takeUtf8()); + else + throw LuaException(error.takeUtf8()); + } +} + +int LuaEngine::pcallWithTraceback(lua_State* state, int nargs, int nresults) { + int msghPosition = lua_gettop(state) - nargs; + lua_rawgeti(m_state, LUA_REGISTRYINDEX, m_pcallTracebackMessageHandlerRegistryId); + lua_insert(state, msghPosition); + int ret = lua_pcall(state, nargs, nresults, msghPosition); + lua_remove(state, msghPosition); + return ret; +} + +int LuaEngine::coresumeWithTraceback(lua_State* state) { + lua_State* co = lua_tothread(state, 1); + if (!co) { + lua_checkstack(state, 2); + lua_pushboolean(state, 0); + lua_pushliteral(state, "bad argument #1 to 'resume' (thread expected)"); + return 2; + } + + int args = lua_gettop(state) - 1; + lua_checkstack(co, args); + if (lua_status(co) == LUA_OK && lua_gettop(co) == 0) { + lua_checkstack(state, 2); + lua_pushboolean(state, 0); + lua_pushliteral(state, "cannot resume dead coroutine"); + return 2; + } + + lua_xmove(state, co, args); + int status = lua_resume(co, state, args); + if (status == LUA_OK || status == LUA_YIELD) { + int res = lua_gettop(co); + lua_checkstack(state, res + 1); + lua_pushboolean(state, 1); + lua_xmove(co, state, res); + return res + 1; + } else { + lua_checkstack(state, 2); + lua_pushboolean(state, 0); + propagateErrorWithTraceback(co, state); + return 2; + } +} + +void LuaEngine::propagateErrorWithTraceback(lua_State* from, lua_State* to) { + if (const char* error = lua_tostring(from, -1)) { + luaL_traceback(to, from, error, 0); // error + traceback + lua_pop(from, 1); + } else { + lua_xmove(from, to, 1); // just error, no traceback + } +} + +char const* LuaEngine::stringPtr(int handleIndex) { + return lua_tostring(m_handleThread, handleIndex); +} + +size_t LuaEngine::stringLength(int handleIndex) { + size_t len = 0; + lua_tolstring(m_handleThread, handleIndex, &len); + return len; +} + +LuaValue LuaEngine::tableGet(bool raw, int handleIndex, LuaValue const& key) { + lua_checkstack(m_state, 1); + + pushHandle(m_state, handleIndex); + pushLuaValue(m_state, key); + if (raw) + lua_rawget(m_state, -2); + else + lua_gettable(m_state, -2); + + LuaValue v = popLuaValue(m_state); + lua_pop(m_state, 1); + return v; +} + +LuaValue LuaEngine::tableGet(bool raw, int handleIndex, char const* key) { + lua_checkstack(m_state, 1); + + pushHandle(m_state, handleIndex); + if (raw) + LuaDetail::rawGetField(m_state, -1, key); + else + lua_getfield(m_state, -1, key); + lua_remove(m_state, -2); + return popLuaValue(m_state); +} + +void LuaEngine::tableSet(bool raw, int handleIndex, LuaValue const& key, LuaValue const& value) { + lua_checkstack(m_state, 1); + + pushHandle(m_state, handleIndex); + pushLuaValue(m_state, key); + pushLuaValue(m_state, value); + + if (raw) + lua_rawset(m_state, -3); + else + lua_settable(m_state, -3); + + lua_pop(m_state, 1); +} + +void LuaEngine::tableSet(bool raw, int handleIndex, char const* key, LuaValue const& value) { + lua_checkstack(m_state, 1); + pushHandle(m_state, handleIndex); + pushLuaValue(m_state, value); + + if (raw) + LuaDetail::rawSetField(m_state, -2, key); + else + lua_setfield(m_state, -2, key); + + lua_pop(m_state, 1); +} + +LuaInt LuaEngine::tableLength(bool raw, int handleIndex) { + if (raw) { + return lua_rawlen(m_handleThread, handleIndex); + + } else { + lua_checkstack(m_state, 1); + pushHandle(m_state, handleIndex); + lua_len(m_state, -1); + LuaInt len = lua_tointeger(m_state, -1); + lua_pop(m_state, 2); + return len; + } +} + +void LuaEngine::tableIterate(int handleIndex, function<bool(LuaValue key, LuaValue value)> iterator) { + lua_checkstack(m_state, 4); + + pushHandle(m_state, handleIndex); + lua_pushnil(m_state); + while (lua_next(m_state, -2) != 0) { + lua_pushvalue(m_state, -2); + LuaValue key = popLuaValue(m_state); + LuaValue value = popLuaValue(m_state); + bool cont = false; + try { + cont = iterator(move(key), move(value)); + } catch (...) { + lua_pop(m_state, 2); + throw; + } + if (!cont) { + lua_pop(m_state, 1); + break; + } + } + + lua_pop(m_state, 1); +} + +Maybe<LuaTable> LuaEngine::tableGetMetatable(int handleIndex) { + lua_checkstack(m_state, 2); + + pushHandle(m_state, handleIndex); + if (lua_getmetatable(m_state, -1) == 0) { + lua_pop(m_state, 1); + return {}; + } + LuaTable table = popLuaValue(m_state).get<LuaTable>(); + lua_pop(m_state, 1); + return table; +} + +void LuaEngine::tableSetMetatable(int handleIndex, LuaTable const& table) { + lua_checkstack(m_state, 2); + + pushHandle(m_state, handleIndex); + pushHandle(m_state, table.handleIndex()); + lua_setmetatable(m_state, -2); + lua_pop(m_state, 1); +} + +void LuaEngine::setContextRequire(int handleIndex, LuaContext::RequireFunction requireFunction) { + lua_checkstack(m_state, 4); + + pushHandle(m_state, handleIndex); + + auto funcUserdata = (LuaContext::RequireFunction*)lua_newuserdata(m_state, sizeof(LuaContext::RequireFunction)); + new (funcUserdata) LuaContext::RequireFunction(move(requireFunction)); + lua_rawgeti(m_state, LUA_REGISTRYINDEX, m_requireFunctionMetatableRegistryId); + lua_setmetatable(m_state, -2); + + lua_pushvalue(m_state, -2); + + auto invokeRequire = [](lua_State* state) { + try { + lua_checkstack(state, 2); + + auto require = (LuaContext::RequireFunction*)lua_touserdata(state, lua_upvalueindex(1)); + auto self = luaEnginePtr(state); + + auto moduleName = self->luaTo<LuaString>(self->popLuaValue(state)); + + lua_pushvalue(state, lua_upvalueindex(2)); + LuaContext context(LuaDetail::LuaHandle(RefPtr<LuaEngine>(self), self->popHandle(state))); + + (*require)(context, moduleName); + return 0; + } catch (LuaInstructionLimitReached const&) { + lua_pushlightuserdata(state, &s_luaInstructionLimitExceptionKey); + return lua_error(state); + } catch (LuaRecursionLimitReached const&) { + lua_pushlightuserdata(state, &s_luaRecursionLimitExceptionKey); + return lua_error(state); + } catch (std::exception const& e) { + luaL_where(state, 1); + lua_pushstring(state, printException(e, true).c_str()); + lua_concat(state, 2); + return lua_error(state); + } + }; + + lua_pushcclosure(m_state, invokeRequire, 2); + + LuaDetail::rawSetField(m_state, -2, "require"); + + lua_pop(m_state, 1); +} + +void LuaEngine::contextLoad(int handleIndex, char const* contents, size_t size, char const* name) { + lua_checkstack(m_state, 2); + + // First load the script... + handleError(m_state, luaL_loadbuffer(m_state, contents, size, name)); + + // Then set the _ENV upvalue for the newly loaded chunk to our context env so + // we load the scripts into the right environment. + pushHandle(m_state, handleIndex); + lua_setupvalue(m_state, -2, 1); + + incrementRecursionLevel(); + int res = pcallWithTraceback(m_state, 0, 0); + decrementRecursionLevel(); + handleError(m_state, res); +} + +LuaDetail::LuaFunctionReturn LuaEngine::contextEval(int handleIndex, String const& lua) { + int stackSize = lua_gettop(m_state); + lua_checkstack(m_state, 2); + + // First, try interpreting the lua as an expression by adding "return", then + // as a statement. This is the same thing the actual lua repl does. + int loadRes = luaL_loadstring(m_state, ("return " + lua).utf8Ptr()); + if (loadRes == LUA_ERRSYNTAX) { + lua_pop(m_state, 1); + loadRes = luaL_loadstring(m_state, lua.utf8Ptr()); + } + handleError(m_state, loadRes); + + pushHandle(m_state, handleIndex); + lua_setupvalue(m_state, -2, 1); + + incrementRecursionLevel(); + int callRes = pcallWithTraceback(m_state, 0, LUA_MULTRET); + decrementRecursionLevel(); + handleError(m_state, callRes); + + int returnValues = lua_gettop(m_state) - stackSize; + if (returnValues == 0) { + return LuaDetail::LuaFunctionReturn(); + } else if (returnValues == 1) { + return LuaDetail::LuaFunctionReturn(popLuaValue(m_state)); + } else { + LuaVariadic<LuaValue> ret(returnValues); + for (int i = returnValues - 1; i >= 0; --i) + ret[i] = popLuaValue(m_state); + return LuaDetail::LuaFunctionReturn(ret); + } +} + +LuaValue LuaEngine::contextGetPath(int handleIndex, String path) { + lua_checkstack(m_state, 2); + pushHandle(m_state, handleIndex); + + std::string utf8Path = path.takeUtf8(); + char* utf8Ptr = &utf8Path[0]; + size_t utf8Size = utf8Path.size(); + + size_t subPathStart = 0; + for (size_t i = 0; i < utf8Size; ++i) { + if (utf8Path[i] == '.') { + utf8Path[i] = '\0'; + + lua_getfield(m_state, -1, utf8Ptr + subPathStart); + lua_remove(m_state, -2); + + if (lua_type(m_state, -1) != LUA_TTABLE) { + lua_pop(m_state, 1); + return LuaNil; + } + + subPathStart = i + 1; + } + } + + lua_getfield(m_state, -1, utf8Ptr + subPathStart); + lua_remove(m_state, -2); + + return popLuaValue(m_state); +} + +void LuaEngine::contextSetPath(int handleIndex, String path, LuaValue const& value) { + lua_checkstack(m_state, 3); + pushHandle(m_state, handleIndex); + + std::string utf8Path = path.takeUtf8(); + char* utf8Ptr = &utf8Path[0]; + size_t utf8Size = utf8Path.size(); + + size_t subPathStart = 0; + for (size_t i = 0; i < utf8Size; ++i) { + if (utf8Path[i] == '.') { + utf8Path[i] = '\0'; + + int type = lua_getfield(m_state, -1, utf8Ptr + subPathStart); + if (type == LUA_TNIL) { + lua_pop(m_state, 1); + lua_newtable(m_state); + lua_pushvalue(m_state, -1); + lua_setfield(m_state, -3, utf8Ptr + subPathStart); + lua_remove(m_state, -2); + } else if (type == LUA_TTABLE) { + lua_remove(m_state, -2); + } else { + lua_pop(m_state, 2); + throw LuaException("Sub-path in setPath is not nil and is not a table"); + } + + subPathStart = i + 1; + } + } + + pushLuaValue(m_state, value); + lua_setfield(m_state, -2, utf8Ptr + subPathStart); + lua_pop(m_state, 1); +} + +int LuaEngine::popHandle(lua_State* state) { + lua_xmove(state, m_handleThread, 1); + return placeHandle(); +} + +void LuaEngine::pushHandle(lua_State* state, int handleIndex) { + lua_pushvalue(m_handleThread, handleIndex); + lua_xmove(m_handleThread, state, 1); +} + +int LuaEngine::copyHandle(int handleIndex) { + lua_pushvalue(m_handleThread, handleIndex); + return placeHandle(); +} + +int LuaEngine::placeHandle() { + if (auto free = m_handleFree.maybeTakeLast()) { + lua_replace(m_handleThread, *free); + return *free; + + } else { + if (m_handleStackMax >= m_handleStackSize) { + if (!lua_checkstack(m_handleThread, m_handleStackSize)) { + throw LuaException("Exhausted the size of the handle thread stack"); + } + m_handleStackSize *= 2; + } + m_handleStackMax += 1; + return m_handleStackMax; + } +} + +LuaFunction LuaEngine::createWrappedFunction(LuaDetail::LuaWrappedFunction function) { + lua_checkstack(m_state, 2); + + auto funcUserdata = (LuaDetail::LuaWrappedFunction*)lua_newuserdata(m_state, sizeof(LuaDetail::LuaWrappedFunction)); + new (funcUserdata) LuaDetail::LuaWrappedFunction(move(function)); + + lua_rawgeti(m_state, LUA_REGISTRYINDEX, m_wrappedFunctionMetatableRegistryId); + lua_setmetatable(m_state, -2); + + auto invokeFunction = [](lua_State* state) { + auto func = (LuaDetail::LuaWrappedFunction*)lua_touserdata(state, lua_upvalueindex(1)); + auto self = luaEnginePtr(state); + + int argumentCount = lua_gettop(state); + try { + // For speed, if the argument count is less than some pre-defined + // value, use a stack array. + int const MaxArrayArgs = 8; + LuaDetail::LuaFunctionReturn res; + if (argumentCount <= MaxArrayArgs) { + Array<LuaValue, MaxArrayArgs> args; + for (int i = argumentCount - 1; i >= 0; --i) + args[i] = self->popLuaValue(state); + res = (*func)(*self, argumentCount, args.ptr()); + } else { + List<LuaValue> args(argumentCount); + for (int i = argumentCount - 1; i >= 0; --i) + args[i] = self->popLuaValue(state); + res = (*func)(*self, argumentCount, args.ptr()); + } + + if (auto val = res.ptr<LuaValue>()) { + self->pushLuaValue(state, *val); + return 1; + } else if (auto vec = res.ptr<LuaVariadic<LuaValue>>()) { + for (auto const& r : *vec) + self->pushLuaValue(state, r); + return (int)vec->size(); + } else { + return 0; + } + } catch (LuaInstructionLimitReached const&) { + lua_pushlightuserdata(state, &s_luaInstructionLimitExceptionKey); + return lua_error(state); + } catch (LuaRecursionLimitReached const&) { + lua_pushlightuserdata(state, &s_luaRecursionLimitExceptionKey); + return lua_error(state); + } catch (std::exception const& e) { + luaL_where(state, 1); + lua_pushstring(state, printException(e, true).c_str()); + lua_concat(state, 2); + return lua_error(state); + } + }; + + lua_pushcclosure(m_state, invokeFunction, 1); + + return LuaFunction(LuaDetail::LuaHandle(RefPtr<LuaEngine>(this), popHandle(m_state))); +} + +LuaFunction LuaEngine::createRawFunction(lua_CFunction function) { + lua_checkstack(m_state, 2); + + lua_pushcfunction(m_state, function); + return LuaFunction(LuaDetail::LuaHandle(RefPtr<LuaEngine>(this), popHandle(m_state))); +} + +void LuaEngine::pushLuaValue(lua_State* state, LuaValue const& luaValue) { + lua_checkstack(state, 1); + + struct Pusher { + LuaEngine* engine; + lua_State* state; + + void operator()(LuaNilType const&) { + lua_pushnil(state); + } + + void operator()(LuaBoolean const& b) { + lua_pushboolean(state, b); + } + + void operator()(LuaInt const& i) { + lua_pushinteger(state, i); + } + + void operator()(LuaFloat const& f) { + lua_pushnumber(state, f); + } + + void operator()(LuaReference const& ref) { + if (&ref.engine() != engine) + throw LuaException("lua reference values cannot be shared between engines"); + engine->pushHandle(state, ref.handleIndex()); + } + }; + + luaValue.call(Pusher{this, state}); +} + +LuaValue LuaEngine::popLuaValue(lua_State* state) { + lua_checkstack(state, 1); + + LuaValue result; + starAssert(!lua_isnone(state, -1)); + switch (lua_type(state, -1)) { + case LUA_TNIL: { + lua_pop(state, 1); + break; + } + case LUA_TBOOLEAN: { + result = lua_toboolean(state, -1) != 0; + lua_pop(state, 1); + break; + } + case LUA_TNUMBER: { + if (lua_isinteger(state, -1)) { + result = lua_tointeger(state, -1); + lua_pop(state, 1); + } else { + result = lua_tonumber(state, -1); + lua_pop(state, 1); + } + break; + } + case LUA_TSTRING: { + result = LuaString(LuaDetail::LuaHandle(RefPtr<LuaEngine>(this), popHandle(state))); + break; + } + case LUA_TTABLE: { + result = LuaTable(LuaDetail::LuaHandle(RefPtr<LuaEngine>(this), popHandle(state))); + break; + } + case LUA_TFUNCTION: { + result = LuaFunction(LuaDetail::LuaHandle(RefPtr<LuaEngine>(this), popHandle(state))); + break; + } + case LUA_TTHREAD: { + result = LuaThread(LuaDetail::LuaHandle(RefPtr<LuaEngine>(this), popHandle(state))); + break; + } + case LUA_TUSERDATA: { + if (lua_getmetatable(state, -1) == 0) { + lua_pop(state, 1); + throw LuaException("Userdata in popLuaValue missing metatable"); + } + lua_pop(state, 1); + result = LuaUserData(LuaDetail::LuaHandle(RefPtr<LuaEngine>(this), popHandle(state))); + break; + } + default: { + lua_pop(state, 1); + throw LuaException("Unsupported type in popLuaValue"); + } + } + + return result; +} + +void LuaEngine::incrementRecursionLevel() { + // We reset the instruction count and profiling timing only on the *top + // level* function entrance, not on recursive entrances. + if (m_recursionLevel == 0) { + m_instructionCount = 0; + } + + if (m_recursionLimit != 0 && m_recursionLevel == m_recursionLimit) + throw LuaRecursionLimitReached(); + + ++m_recursionLevel; +} + +void LuaEngine::decrementRecursionLevel() { + starAssert(m_recursionLevel != 0); + --m_recursionLevel; +} + +void LuaEngine::updateCountHook() { + if (m_instructionLimit || m_profilingEnabled) + lua_sethook(m_state, &LuaEngine::countHook, LUA_MASKCOUNT, m_instructionMeasureInterval); + else + lua_sethook(m_state, &LuaEngine::countHook, 0, 0); +} + +int LuaEngine::s_luaInstructionLimitExceptionKey = 0; +int LuaEngine::s_luaRecursionLimitExceptionKey = 0; + +void LuaDetail::rawSetField(lua_State* state, int index, char const* key) { + lua_checkstack(state, 1); + + int absTableIndex = lua_absindex(state, index); + lua_pushstring(state, key); + + // Move the newly pushed key to the secont to top spot, leaving the value in + // the top spot. + lua_insert(state, -2); + + // Pops the value and the key + lua_rawset(state, absTableIndex); +} + +void LuaDetail::rawGetField(lua_State* state, int index, char const* key) { + lua_checkstack(state, 2); + + int absTableIndex = lua_absindex(state, index); + lua_pushstring(state, key); + + // Pops the key + lua_rawget(state, absTableIndex); +} + +void LuaDetail::shallowCopy(lua_State* state, int sourceIndex, int targetIndex) { + lua_checkstack(state, 3); + + int absSourceIndex = lua_absindex(state, sourceIndex); + int absTargetIndex = lua_absindex(state, targetIndex); + + lua_pushnil(state); + while (lua_next(state, absSourceIndex) != 0) { + lua_pushvalue(state, -2); + lua_insert(state, -2); + lua_rawset(state, absTargetIndex); + } +} + +LuaTable LuaDetail::jsonContainerToTable(LuaEngine& engine, Json const& container) { + if (!container.isType(Json::Type::Array) && !container.isType(Json::Type::Object)) + throw LuaException("jsonContainerToTable called on improper json type"); + + auto newIndexMetaMethod = [](LuaTable const& table, LuaValue const& key, LuaValue const& value) { + auto mt = table.getMetatable(); + auto nils = mt->rawGet<LuaTable>("__nils"); + + // If we are setting an entry to nil, need to add a bogus integer entry + // to the __nils table, otherwise need to set the entry *in* the __nils + // table to nil and remove it. + if (value == LuaNil) + nils.rawSet(key, 0); + else + nils.rawSet(key, LuaNil); + table.rawSet(key, value); + }; + + auto mt = engine.createTable(); + auto nils = engine.createTable(); + mt.rawSet("__nils", nils); + mt.rawSet("__newindex", engine.createFunction(newIndexMetaMethod)); + if (container.isType(Json::Type::Array)) + mt.rawSet("__typehint", 1); + else + mt.rawSet("__typehint", 2); + + auto table = engine.createTable(); + table.setMetatable(mt); + + if (container.isType(Json::Type::Array)) { + auto vlist = container.arrayPtr(); + for (size_t i = 0; i < vlist->size(); ++i) { + auto const& val = (*vlist)[i]; + if (val) + table.rawSet(i + 1, val); + else + nils.rawSet(i + 1, 0); + } + } else { + for (auto const& pair : *container.objectPtr()) { + if (pair.second) + table.rawSet(pair.first, pair.second); + else + nils.rawSet(pair.first, 0); + } + } + + return table; +} + +Maybe<Json> LuaDetail::tableToJsonContainer(LuaTable const& table) { + JsonObject stringEntries; + Map<unsigned, Json> intEntries; + int typeHint = 0; + + if (auto mt = table.getMetatable()) { + if (auto th = mt->get<Maybe<int>>("__typehint")) + typeHint = *th; + + if (auto nils = mt->get<Maybe<LuaTable>>("__nils")) { + bool failedConversion = false; + // Nil entries just have a garbage integer as their value + nils->iterate([&](LuaValue const& key, LuaValue const&) { + if (auto i = asInteger(key)) { + intEntries[*i] = Json(); + } else { + if (auto str = table.engine().luaMaybeTo<String>(key)) { + stringEntries[str.take()] = Json(); + } else { + failedConversion = true; + return false; + } + } + return true; + }); + if (failedConversion) + return {}; + } + } + + bool failedConversion = false; + table.iterate([&](LuaValue key, LuaValue value) { + auto jsonValue = table.engine().luaMaybeTo<Json>(value); + if (!jsonValue) { + failedConversion = true; + return false; + } + + if (auto i = asInteger(key)) { + intEntries[*i] = jsonValue.take(); + } else { + auto stringKey = table.engine().luaMaybeTo<String>(move(key)); + if (!stringKey) { + failedConversion = true; + return false; + } + + stringEntries[stringKey.take()] = jsonValue.take(); + } + + return true; + }); + + if (failedConversion) + return {}; + + bool interpretAsList = stringEntries.empty() + && (typeHint == 1 || (typeHint != 2 && !intEntries.empty() && prev(intEntries.end())->first == intEntries.size())); + if (interpretAsList) { + JsonArray list; + for (auto& p : intEntries) + list.set(p.first - 1, move(p.second)); + return Json(move(list)); + } else { + for (auto& p : intEntries) + stringEntries[toString(p.first)] = move(p.second); + return Json(move(stringEntries)); + } +} + +Json LuaDetail::jarrayCreate() { + return JsonArray(); +} + +Json LuaDetail::jobjectCreate() { + return JsonObject(); +} + +void LuaDetail::jcontRemove(LuaTable const& table, LuaValue const& key) { + if (auto mt = table.getMetatable()) { + if (auto nils = mt->rawGet<Maybe<LuaTable>>("__nils")) + nils->rawSet(key, LuaNil); + } + + table.rawSet(key, LuaNil); +} + +size_t LuaDetail::jcontSize(LuaTable const& table) { + size_t elemCount = 0; + size_t highestIndex = 0; + bool hintList = false; + + if (auto mt = table.getMetatable()) { + if (mt->rawGet<Maybe<int>>("__typehint") == 1) + hintList = true; + + if (auto nils = mt->rawGet<Maybe<LuaTable>>("__nils")) { + nils->iterate([&](LuaValue const& key, LuaValue const&) { + auto i = asInteger(key); + if (i && *i >= 0) + highestIndex = max<int>(*i, highestIndex); + else + hintList = false; + ++elemCount; + }); + } + } + + table.iterate([&](LuaValue const& key, LuaValue const&) { + auto i = asInteger(key); + if (i && *i >= 0) + highestIndex = max<int>(*i, highestIndex); + else + hintList = false; + ++elemCount; + }); + + if (hintList) + return highestIndex; + else + return elemCount; +} + +void LuaDetail::jcontResize(LuaTable const& table, size_t targetSize) { + if (auto mt = table.getMetatable()) { + if (auto nils = mt->rawGet<Maybe<LuaTable>>("__nils")) { + nils->iterate([&](LuaValue const& key, LuaValue const&) { + auto i = asInteger(key); + if (i && *i > 0 && (size_t)*i > targetSize) + nils->rawSet(key, LuaNil); + }); + } + } + + table.iterate([&](LuaValue const& key, LuaValue const&) { + auto i = asInteger(key); + if (i && *i > 0 && (size_t)*i > targetSize) + table.rawSet(key, LuaNil); + }); + + table.set(targetSize, table.get(targetSize)); +} + +Maybe<LuaInt> LuaDetail::asInteger(LuaValue const& v) { + if (v.is<LuaInt>()) + return v.get<LuaInt>(); + if (v.is<LuaFloat>()) { + auto f = v.get<LuaFloat>(); + if ((LuaFloat)(LuaInt)f == f) + return (LuaInt)f; + return {}; + } + if (v.is<LuaString>()) + return maybeLexicalCast<LuaInt>(v.get<LuaString>().ptr()); + return {}; +} + +} diff --git a/source/core/StarLua.hpp b/source/core/StarLua.hpp new file mode 100644 index 0000000..1c34cc4 --- /dev/null +++ b/source/core/StarLua.hpp @@ -0,0 +1,2163 @@ +#ifndef STAR_LUA_HPP +#define STAR_LUA_HPP + +#include <typeindex> +#include <type_traits> +#include <lua.hpp> + +#include "StarLexicalCast.hpp" +#include "StarString.hpp" +#include "StarJson.hpp" +#include "StarRefPtr.hpp" + +namespace Star { + +class LuaEngine; +typedef RefPtr<LuaEngine> LuaEnginePtr; + +// Basic unspecified lua exception +STAR_EXCEPTION(LuaException, StarException); + +// Thrown when trying to parse an incomplete statement, useful for implementing +// REPL loops, uses the incomplete statement marker '<eof>' as the standard lua +// repl does. +STAR_EXCEPTION(LuaIncompleteStatementException, LuaException); + +// Thrown when the instruction limit is reached, if the instruction limit is +// set. +STAR_EXCEPTION(LuaInstructionLimitReached, LuaException); + +// Thrown when the engine recursion limit is reached, if the recursion limit is +// set. +STAR_EXCEPTION(LuaRecursionLimitReached, LuaException); + +// Thrown when an incorrect lua type is passed to something in C++ expecting a +// different type. +STAR_EXCEPTION(LuaConversionException, LuaException); + +typedef Empty LuaNilType; +typedef bool LuaBoolean; +typedef lua_Integer LuaInt; +typedef lua_Number LuaFloat; +class LuaString; +class LuaTable; +class LuaFunction; +class LuaThread; +class LuaUserData; +typedef Variant<LuaNilType, LuaBoolean, LuaInt, LuaFloat, LuaString, LuaTable, LuaFunction, LuaThread, LuaUserData> LuaValue; + +// Used to wrap multiple return values from calling a lua function or to pass +// multiple values as arguments to a lua function from a container. If this is +// used as an argument to a lua callback function, it must be the final +// argument of the function! +template <typename T> +class LuaVariadic : public List<T> { +public: + using List<T>::List; +}; + +// Unpack a container and apply each of the arguments separately to a lua +// function, similar to lua's unpack. +template <typename Container> +LuaVariadic<typename std::decay<Container>::type::value_type> luaUnpack(Container&& c); + +// Similar to LuaVariadic, but a tuple type so automatic per-entry type +// conversion is done. This can only be used as the return value of a wrapped +// c++ function, or as a type for the return value of calling a lua function. +template <typename... Types> +class LuaTupleReturn : public tuple<Types...> { +public: + typedef tuple<Types...> Base; + + explicit LuaTupleReturn(Types const&... args); + template <typename... UTypes> + explicit LuaTupleReturn(UTypes&&... args); + template <typename... UTypes> + explicit LuaTupleReturn(UTypes const&... args); + LuaTupleReturn(LuaTupleReturn const& rhs); + LuaTupleReturn(LuaTupleReturn&& rhs); + template <typename... UTypes> + LuaTupleReturn(LuaTupleReturn<UTypes...> const& rhs); + template <typename... UTypes> + LuaTupleReturn(LuaTupleReturn<UTypes...>&& rhs); + + LuaTupleReturn& operator=(LuaTupleReturn const& rhs); + LuaTupleReturn& operator=(LuaTupleReturn&& rhs); + template <typename... UTypes> + LuaTupleReturn& operator=(LuaTupleReturn<UTypes...> const& rhs); + template <typename... UTypes> + LuaTupleReturn& operator=(LuaTupleReturn<UTypes...>&& rhs); +}; + +// std::tie for LuaTupleReturn +template <typename... Types> +LuaTupleReturn<Types&...> luaTie(Types&... args); + +// Constructs a LuaTupleReturn from the given arguments similar to make_tuple +template <typename... Types> +LuaTupleReturn<typename std::decay<Types>::type...> luaTupleReturn(Types&&... args); + +namespace LuaDetail { + struct LuaHandle { + LuaHandle(LuaEnginePtr engine, int handleIndex); + ~LuaHandle(); + + LuaHandle(LuaHandle const& other); + LuaHandle(LuaHandle&& other); + + LuaHandle& operator=(LuaHandle const& other); + LuaHandle& operator=(LuaHandle&& other); + + LuaEnginePtr engine; + int handleIndex; + }; + + // Not meant to be used directly, exposes a raw interface for wrapped C++ + // functions to be wrapped with the least amount of overhead. Arguments are + // passed non-const so that they can be moved into wrapped functions that + // take values without copying. + typedef Variant<LuaValue, LuaVariadic<LuaValue>> LuaFunctionReturn; + typedef function<LuaFunctionReturn(LuaEngine&, size_t argc, LuaValue* argv)> LuaWrappedFunction; +} + +// Prints the lua value similar to lua's print function, except it makes an +// attempt at printing tables. +std::ostream& operator<<(std::ostream& os, LuaValue const& value); + +// Holds a reference to a LuaEngine and a value held internally inside the +// registry of that engine. The lifetime of the LuaEngine will be extended +// until all LuaReferences referencing it are destroyed. +class LuaReference { +public: + LuaReference(LuaDetail::LuaHandle handle); + + LuaReference(LuaReference&&) = default; + LuaReference& operator=(LuaReference&&) = default; + + LuaReference(LuaReference const&) = default; + LuaReference& operator=(LuaReference const&) = default; + + bool operator==(LuaReference const& rhs) const; + bool operator!=(LuaReference const& rhs) const; + + LuaEngine& engine() const; + int handleIndex() const; + +private: + LuaDetail::LuaHandle m_handle; +}; + +class LuaString : public LuaReference { +public: + using LuaReference::LuaReference; + + char const* ptr() const; + size_t length() const; + + String toString() const; +}; + +bool operator==(LuaString const& s1, LuaString const& s2); +bool operator==(LuaString const& s1, char const* s2); +bool operator==(LuaString const& s1, std::string const& s2); +bool operator==(LuaString const& s1, String const& s2); +bool operator==(char const* s1, LuaString const& s2); +bool operator==(std::string const& s1, LuaString const& s2); +bool operator==(String const& s1, LuaString const& s2); + +bool operator!=(LuaString const& s1, LuaString const& s2); +bool operator!=(LuaString const& s1, char const* s2); +bool operator!=(LuaString const& s1, std::string const& s2); +bool operator!=(LuaString const& s1, String const& s2); +bool operator!=(char const* s1, LuaString const& s2); +bool operator!=(std::string const& s1, LuaString const& s2); +bool operator!=(String const& s1, LuaString const& s2); + +class LuaTable : public LuaReference { +public: + using LuaReference::LuaReference; + + template <typename T = LuaValue, typename K> + T get(K key) const; + template <typename T = LuaValue> + T get(char const* key) const; + + template <typename T, typename K> + void set(K key, T t) const; + template <typename T> + void set(char const* key, T t) const; + + // Shorthand for get(path) != LuaNil + template <typename K> + bool contains(K key) const; + bool contains(char const* key) const; + + // Shorthand for setting to LuaNil + template <typename K> + void remove(K key) const; + void remove(char const* key) const; + + // Result of lua # operator + LuaInt length() const; + + // If iteration function returns bool, returning false signals stopping. + template <typename Function> + void iterate(Function&& iterator) const; + + template <typename Return, typename... Args, typename Function> + void iterateWithSignature(Function&& func) const; + + Maybe<LuaTable> getMetatable() const; + void setMetatable(LuaTable const& table) const; + + template <typename T = LuaValue, typename K> + T rawGet(K key) const; + template <typename T = LuaValue> + T rawGet(char const* key) const; + + template <typename T, typename K> + void rawSet(K key, T t) const; + template <typename T> + void rawSet(char const* key, T t) const; + + LuaInt rawLength() const; +}; + +class LuaFunction : public LuaReference { +public: + using LuaReference::LuaReference; + + template <typename Ret = LuaValue, typename... Args> + Ret invoke(Args const&... args) const; +}; + +class LuaThread : public LuaReference { +public: + using LuaReference::LuaReference; + enum class Status { + Dead, + Active, + Error + }; + + // Will return a value if the thread has yielded a value, and nothing if the + // thread has finished execution + template <typename Ret = LuaValue, typename... Args> + Maybe<Ret> resume(Args const&... args) const; + void pushFunction(LuaFunction const& func) const; + Status status() const; +}; + +// Keeping LuaReferences in LuaUserData will lead to circular references to +// LuaEngine, in addition to circular references in Lua which the Lua +// garbage collector can't collect. Don't put LuaReferences in LuaUserData. +class LuaUserData : public LuaReference { +public: + using LuaReference::LuaReference; + + template <typename T> + bool is() const; + + template <typename T> + T& get() const; +}; + +LuaValue const LuaNil = LuaValue(); + +class LuaCallbacks { +public: + template <typename Function> + void registerCallback(String name, Function&& func); + + template <typename Return, typename... Args, typename Function> + void registerCallbackWithSignature(String name, Function&& func); + + LuaCallbacks& merge(LuaCallbacks const& callbacks); + + StringMap<LuaDetail::LuaWrappedFunction> const& callbacks() const; + +private: + StringMap<LuaDetail::LuaWrappedFunction> m_callbacks; +}; + +template <typename T> +class LuaMethods { +public: + template <typename Function> + void registerMethod(String name, Function&& func); + + template <typename Return, typename... Args, typename Function> + void registerMethodWithSignature(String name, Function&& func); + + StringMap<LuaDetail::LuaWrappedFunction> const& methods() const; + +private: + StringMap<LuaDetail::LuaWrappedFunction> m_methods; +}; + +// A single execution context from a LuaEngine that manages a (mostly) distinct +// lua environment. Each LuaContext's global environment is separate and one +// LuaContext can (mostly) not affect any other. +class LuaContext : protected LuaTable { +public: + typedef function<void(LuaContext&, LuaString const&)> RequireFunction; + + using LuaTable::LuaTable; + + using LuaTable::get; + using LuaTable::set; + using LuaTable::contains; + using LuaTable::remove; + using LuaTable::engine; + + // Splits the path by '.' character, so can get / set values in tables inside + // other tables. If any table in the path is not a table but is accessed as + // one, instead returns LuaNil. + template <typename T = LuaValue> + T getPath(String path) const; + // Shorthand for getPath != LuaNil + bool containsPath(String path) const; + // Will create new tables if the key contains paths that are nil + template <typename T> + void setPath(String path, T value); + + // Load the given code (either source or bytecode) into this context as a new + // chunk. It is not necessary to provide the name again if given bytecode. + void load(char const* contents, size_t size, char const* name = nullptr); + void load(String const& contents, String const& name = String()); + void load(ByteArray const& contents, String const& name = String()); + + // Evaluate a piece of lua code in this context, similar to the lua repl. + // Can evaluate both expressions and statements. + template <typename T = LuaValue> + T eval(String const& lua); + + // Override the built-in require function with the given function that takes + // this LuaContext and the module name to load. + void setRequireFunction(RequireFunction requireFunction); + + void setCallbacks(String const& tableName, LuaCallbacks const& callbacks) const; + + // For convenience, invokePath methods are equivalent to calling getPath(key) + // to get a function, and then invoking it. + + template <typename Ret = LuaValue, typename... Args> + Ret invokePath(String const& key, Args const&... args) const; + + // For convenience, calls to LuaEngine conversion / create functions are + // duplicated here. + + template <typename T> + LuaValue luaFrom(T&& t); + template <typename T> + LuaValue luaFrom(T const& t); + template <typename T> + Maybe<T> luaMaybeTo(LuaValue&& v); + template <typename T> + Maybe<T> luaMaybeTo(LuaValue const& v); + template <typename T> + T luaTo(LuaValue const& v); + template <typename T> + T luaTo(LuaValue&& v); + + LuaString createString(String const& str); + LuaString createString(char const* str); + + LuaTable createTable(); + + template <typename Container> + LuaTable createTable(Container const& map); + + template <typename Container> + LuaTable createArrayTable(Container const& array); + + template <typename Function> + LuaFunction createFunction(Function&& func); + + template <typename Return, typename... Args, typename Function> + LuaFunction createFunctionWithSignature(Function&& func); + + template <typename T> + LuaUserData createUserData(T t); +}; + +// Types that want to participate in automatic lua conversion should specialize +// this template and provide static to and from methods on it. The method +// signatures will be called like: +// LuaValue from(LuaEngine& engine, T t); +// Maybe<T> to(LuaEngine& engine, LuaValue v); +// The methods can also take 'T const&' or 'LuaValue const&' as parameters, and +// the 'to' method can also return a bare T if conversion cannot fail. +template <typename T> +struct LuaConverter; + +// UserData types that want to expose methods to lua should specialize this +// template. +template <typename T> +struct LuaUserDataMethods { + static LuaMethods<T> make(); +}; + +// Convenience converter that simply converts to/from LuaUserData, can be +// derived from by a declared converter. +template <typename T> +struct LuaUserDataConverter { + static LuaValue from(LuaEngine& engine, T t); + static Maybe<T> to(LuaEngine& engine, LuaValue const& v); +}; + +struct LuaProfileEntry { + // Source name of the chunk the function was defined in + String source; + // Line number in the chunk of the beginning of the function definition + unsigned sourceLine; + // Name of the function, if it can be determined + Maybe<String> name; + // Scope of the function, if it can be determined + Maybe<String> nameScope; + // Time taken within this function itself + int64_t selfTime; + // Total time taken within this function or sub functions + int64_t totalTime; + // Calls from this function + HashMap<tuple<String, unsigned>, shared_ptr<LuaProfileEntry>> calls; +}; + +// This class represents one execution engine in lua, holding a single +// lua_State. Multiple contexts can be created, and they will have separate +// global environments and cannot affect each other. Individual LuaEngines / +// LuaContexts are not thread safe, use one LuaEngine per thread. +class LuaEngine : public RefCounter { +public: + // If 'safe' is true, then creates a lua engine with all builtin lua + // functions that can affect the real world disabled. + static LuaEnginePtr create(bool safe = true); + + ~LuaEngine(); + + LuaEngine(LuaEngine const&) = delete; + LuaEngine(LuaEngine&&) = default; + + LuaEngine& operator=(LuaEngine const&) = delete; + LuaEngine& operator=(LuaEngine&&) = default; + + // Set the instruction limit for computation sequences in the engine. During + // any function invocation, thread resume, or code evaluation, an instruction + // counter will be started. In the event that the instruction counter + // becomes greater than the given limit, a LuaException will be thrown. The + // count is only reset when the initial entry into LuaEngine is returned, + // recursive entries into LuaEngine accumulate the same instruction counter. + // 0 disables the instruction limit. + void setInstructionLimit(uint64_t instructionLimit = 0); + uint64_t instructionLimit() const; + + // If profiling is enabled, then every 'measureInterval' instructions, the + // function call stack will be recorded, and a summary of function timing can + // be printed using profileReport + void setProfilingEnabled(bool profilingEnabled); + bool profilingEnabled() const; + + // Print a summary of the profiling data gathered since profiling was last + // enabled. + List<LuaProfileEntry> getProfile(); + + // If an instruction limit is set or profiling is neabled, this field + // describes the resolution of instruction count measurement, and affects the + // accuracy of profiling and the instruction count limit. Defaults to 1000 + void setInstructionMeasureInterval(unsigned measureInterval = 1000); + unsigned instructionMeasureInterval() const; + + // Sets the LuaEngine recursion limit, limiting the number of times a + // LuaEngine call may directly or inderectly trigger a call back into the + // LuaEngine, preventing a C++ stack overflow. 0 disables the limit. + void setRecursionLimit(unsigned recursionLimit = 0); + unsigned recursionLimit() const; + + // Compile a given script into bytecode. If name is given, then it will be + // used as the internal name for the resulting chunk and will provide better + // error messages. + // + // Unfortunately the only way to completely ensure that a single script will + // execute in two separate contexts and truly be isolated is to compile the + // script to bytecode and load once in each context as a separate chunk. + ByteArray compile(char const* contents, size_t size, char const* name = nullptr); + ByteArray compile(String const& contents, String const& name = String()); + ByteArray compile(ByteArray const& contents, String const& name = String()); + + // Generic from/to lua conversion, calls template specialization of + // LuaConverter for actual conversion. + template <typename T> + LuaValue luaFrom(T&& t); + template <typename T> + LuaValue luaFrom(T const& t); + template <typename T> + Maybe<T> luaMaybeTo(LuaValue&& v); + template <typename T> + Maybe<T> luaMaybeTo(LuaValue const& v); + + // Wraps luaMaybeTo, throws an exception if conversion fails. + template <typename T> + T luaTo(LuaValue const& v); + template <typename T> + T luaTo(LuaValue&& v); + + LuaString createString(String const& str); + LuaString createString(char const* str); + + LuaTable createTable(); + + template <typename Container> + LuaTable createTable(Container const& map); + + template <typename Container> + LuaTable createArrayTable(Container const& array); + + // Creates a function and deduces the signature of the function using + // FunctionTraits. As a convenience, the given function may optionally take + // a LuaEngine& parameter as the first parameter, and if it does, when called + // the function will get a reference to the calling LuaEngine. + template <typename Function> + LuaFunction createFunction(Function&& func); + + // If the function signature is not deducible using FunctionTraits, you can + // specify the return and argument types manually using this createFunction + // version. + template <typename Return, typename... Args, typename Function> + LuaFunction createFunctionWithSignature(Function&& func); + + LuaThread createThread(); + + template <typename T> + LuaUserData createUserData(T t); + + LuaContext createContext(); + + // Global environment changes only affect newly created contexts + + template <typename T = LuaValue, typename K> + T getGlobal(K key); + template <typename T = LuaValue> + T getGlobal(char const* key); + + template <typename T, typename K> + void setGlobal(K key, T value); + + template <typename T> + void setGlobal(char const* key, T value); + + // Perform either a full or incremental garbage collection. + void collectGarbage(Maybe<unsigned> steps = {}); + + // Stop / start automatic garbage collection + void setAutoGarbageCollection(bool autoGarbageColleciton); + + // Tune the pause and step values of the lua garbage collector + void tuneAutoGarbageCollection(float pause, float stepMultiplier); + + // Bytes in use by lua + size_t memoryUsage() const; + +private: + friend struct LuaDetail::LuaHandle; + friend class LuaReference; + friend class LuaString; + friend class LuaTable; + friend class LuaFunction; + friend class LuaThread; + friend class LuaUserData; + friend class LuaContext; + + LuaEngine() = default; + + // Get the LuaEngine* out of the lua registry magic entry. Uses 1 stack + // space, and does not call lua_checkstack. + static LuaEngine* luaEnginePtr(lua_State* state); + // Counts instructions when instruction limiting is enabled. + static void countHook(lua_State* state, lua_Debug* ar); + + static void* allocate(void* userdata, void* ptr, size_t oldSize, size_t newSize); + + // Pops lua error from stack and throws LuaException + void handleError(lua_State* state, int res); + + // lua_pcall with a better message handler that includes a traceback. + int pcallWithTraceback(lua_State* state, int nargs, int nresults); + + // override for lua coroutine resume with traceback + static int coresumeWithTraceback(lua_State* state); + // propagates errors from one state to another, i.e. past thread boundaries + // pops error off the top of the from stack and pushes onto the to stack + static void propagateErrorWithTraceback(lua_State* from, lua_State* to); + + char const* stringPtr(int handleIndex); + size_t stringLength(int handleIndex); + + LuaValue tableGet(bool raw, int handleIndex, LuaValue const& key); + LuaValue tableGet(bool raw, int handleIndex, char const* key); + + void tableSet(bool raw, int handleIndex, LuaValue const& key, LuaValue const& value); + void tableSet(bool raw, int handleIndex, char const* key, LuaValue const& value); + + LuaInt tableLength(bool raw, int handleIndex); + + void tableIterate(int handleIndex, function<bool(LuaValue, LuaValue)> iterator); + + Maybe<LuaTable> tableGetMetatable(int handleIndex); + void tableSetMetatable(int handleIndex, LuaTable const& table); + + template <typename... Args> + LuaDetail::LuaFunctionReturn callFunction(int handleIndex, Args const&... args); + + template <typename... Args> + Maybe<LuaDetail::LuaFunctionReturn> resumeThread(int handleIndex, Args const&... args); + void threadPushFunction(int threadIndex, int functionIndex); + LuaThread::Status threadStatus(int handleIndex); + + template <typename T> + void registerUserDataType(); + + template <typename T> + bool userDataIsType(int handleIndex); + + template <typename T> + T* getUserData(int handleIndex); + + void setContextRequire(int handleIndex, LuaContext::RequireFunction requireFunction); + + void contextLoad(int handleIndex, char const* contents, size_t size, char const* name); + + LuaDetail::LuaFunctionReturn contextEval(int handleIndex, String const& lua); + + LuaValue contextGetPath(int handleIndex, String path); + void contextSetPath(int handleIndex, String path, LuaValue const& value); + + int popHandle(lua_State* state); + void pushHandle(lua_State* state, int handleIndex); + int copyHandle(int handleIndex); + void destroyHandle(int handleIndex); + + int placeHandle(); + + LuaFunction createWrappedFunction(LuaDetail::LuaWrappedFunction function); + LuaFunction createRawFunction(lua_CFunction func); + + void pushLuaValue(lua_State* state, LuaValue const& luaValue); + LuaValue popLuaValue(lua_State* state); + + template <typename T> + size_t pushArgument(lua_State* state, T const& arg); + + template <typename T> + size_t pushArgument(lua_State* state, LuaVariadic<T> const& args); + + size_t doPushArguments(lua_State*); + template <typename First, typename... Rest> + size_t doPushArguments(lua_State* state, First const& first, Rest const&... rest); + + template <typename... Args> + size_t pushArguments(lua_State* state, Args const&... args); + + void incrementRecursionLevel(); + void decrementRecursionLevel(); + + void updateCountHook(); + + // The following fields exist to use their addresses as unique lightuserdata, + // as is recommended by the lua docs. + static int s_luaInstructionLimitExceptionKey; + static int s_luaRecursionLimitExceptionKey; + + lua_State* m_state; + int m_pcallTracebackMessageHandlerRegistryId; + int m_scriptDefaultEnvRegistryId; + int m_wrappedFunctionMetatableRegistryId; + int m_requireFunctionMetatableRegistryId; + HashMap<std::type_index, int> m_registeredUserDataTypes; + + lua_State* m_handleThread; + int m_handleStackSize; + int m_handleStackMax; + List<int> m_handleFree; + + uint64_t m_instructionLimit; + bool m_profilingEnabled; + unsigned m_instructionMeasureInterval; + uint64_t m_instructionCount; + unsigned m_recursionLevel; + unsigned m_recursionLimit; + HashMap<tuple<String, unsigned>, shared_ptr<LuaProfileEntry>> m_profileEntries; +}; + +// Built in conversions + +template <> +struct LuaConverter<bool> { + static LuaValue from(LuaEngine&, bool v) { + return v; + } + + static Maybe<bool> to(LuaEngine&, LuaValue const& v) { + if (auto b = v.ptr<LuaBoolean>()) + return *b; + if (v == LuaNil) + return false; + return true; + } +}; + +template <typename T> +struct LuaIntConverter { + static LuaValue from(LuaEngine&, T v) { + return LuaInt(v); + } + + static Maybe<T> to(LuaEngine&, LuaValue const& v) { + if (auto n = v.ptr<LuaInt>()) + return *n; + if (auto n = v.ptr<LuaFloat>()) + return *n; + if (auto s = v.ptr<LuaString>()) { + if (auto n = maybeLexicalCast<LuaInt>(s->ptr())) + return *n; + if (auto n = maybeLexicalCast<LuaFloat>(s->ptr())) + return *n; + } + return {}; + } +}; + +template <> +struct LuaConverter<char> : LuaIntConverter<char> {}; + +template <> +struct LuaConverter<unsigned char> : LuaIntConverter<unsigned char> {}; + +template <> +struct LuaConverter<short> : LuaIntConverter<short> {}; + +template <> +struct LuaConverter<unsigned short> : LuaIntConverter<unsigned short> {}; + +template <> +struct LuaConverter<long> : LuaIntConverter<long> {}; + +template <> +struct LuaConverter<unsigned long> : LuaIntConverter<unsigned long> {}; + +template <> +struct LuaConverter<int> : LuaIntConverter<int> {}; + +template <> +struct LuaConverter<unsigned int> : LuaIntConverter<unsigned int> {}; + +template <> +struct LuaConverter<long long> : LuaIntConverter<long long> {}; + +template <> +struct LuaConverter<unsigned long long> : LuaIntConverter<unsigned long long> {}; + +template <typename T> +struct LuaFloatConverter { + static LuaValue from(LuaEngine&, T v) { + return LuaFloat(v); + } + + static Maybe<T> to(LuaEngine&, LuaValue const& v) { + if (auto n = v.ptr<LuaFloat>()) + return *n; + if (auto n = v.ptr<LuaInt>()) + return *n; + if (auto s = v.ptr<LuaString>()) { + if (auto n = maybeLexicalCast<LuaFloat>(s->ptr())) + return *n; + if (auto n = maybeLexicalCast<LuaInt>(s->ptr())) + return *n; + } + return {}; + } +}; + +template <> +struct LuaConverter<float> : LuaFloatConverter<float> {}; + +template <> +struct LuaConverter<double> : LuaFloatConverter<double> {}; + +template <> +struct LuaConverter<String> { + static LuaValue from(LuaEngine& engine, String const& v) { + return engine.createString(v); + } + + static Maybe<String> to(LuaEngine&, LuaValue const& v) { + if (v.is<LuaString>()) + return String(v.get<LuaString>().ptr()); + if (v.is<LuaInt>()) + return String(toString(v.get<LuaInt>())); + if (v.is<LuaFloat>()) + return String(toString(v.get<LuaFloat>())); + return {}; + } +}; + +template <> +struct LuaConverter<std::string> { + static LuaValue from(LuaEngine& engine, std::string const& v) { + return engine.createString(v.c_str()); + } + + static Maybe<std::string> to(LuaEngine& engine, LuaValue v) { + return engine.luaTo<String>(move(v)).takeUtf8(); + } +}; + +template <> +struct LuaConverter<char const*> { + static LuaValue from(LuaEngine& engine, char const* v) { + return engine.createString(v); + } +}; + +template <size_t s> +struct LuaConverter<char[s]> { + static LuaValue from(LuaEngine& engine, char const v[s]) { + return engine.createString(v); + } +}; + +template <> +struct LuaConverter<LuaString> { + static LuaValue from(LuaEngine&, LuaString v) { + return LuaValue(move(v)); + } + + static Maybe<LuaString> to(LuaEngine& engine, LuaValue v) { + if (v.is<LuaString>()) + return LuaString(move(v.get<LuaString>())); + if (v.is<LuaInt>()) + return engine.createString(toString(v.get<LuaInt>())); + if (v.is<LuaFloat>()) + return engine.createString(toString(v.get<LuaFloat>())); + return {}; + } +}; + +template <typename T> +struct LuaValueConverter { + static LuaValue from(LuaEngine&, T v) { + return v; + } + + static Maybe<T> to(LuaEngine&, LuaValue v) { + if (auto p = v.ptr<T>()) { + return move(*p); + } + return {}; + } +}; + +template <> +struct LuaConverter<LuaTable> : LuaValueConverter<LuaTable> {}; + +template <> +struct LuaConverter<LuaFunction> : LuaValueConverter<LuaFunction> {}; + +template <> +struct LuaConverter<LuaThread> : LuaValueConverter<LuaThread> {}; + +template <> +struct LuaConverter<LuaUserData> : LuaValueConverter<LuaUserData> {}; + +template <> +struct LuaConverter<LuaValue> { + static LuaValue from(LuaEngine&, LuaValue v) { + return v; + } + + static LuaValue to(LuaEngine&, LuaValue v) { + return v; + } +}; + +template <typename T> +struct LuaConverter<Maybe<T>> { + static LuaValue from(LuaEngine& engine, Maybe<T> const& v) { + if (v) + return engine.luaFrom<T>(*v); + else + return LuaNil; + } + + static LuaValue from(LuaEngine& engine, Maybe<T>&& v) { + if (v) + return engine.luaFrom<T>(v.take()); + else + return LuaNil; + } + + static Maybe<Maybe<T>> to(LuaEngine& engine, LuaValue const& v) { + if (v != LuaNil) { + if (auto conv = engine.luaMaybeTo<T>(v)) + return conv; + else + return {}; + } else { + return Maybe<T>(); + } + } + + static Maybe<Maybe<T>> to(LuaEngine& engine, LuaValue&& v) { + if (v != LuaNil) { + if (auto conv = engine.luaMaybeTo<T>(move(v))) + return conv; + else + return {}; + } else { + return Maybe<T>(); + } + } +}; + +template <typename T> +struct LuaMapConverter { + static LuaValue from(LuaEngine& engine, T const& v) { + return engine.createTable(v); + } + + static Maybe<T> to(LuaEngine& engine, LuaValue const& v) { + auto table = v.ptr<LuaTable>(); + if (!table) + return {}; + + T result; + bool failed = false; + table->iterate([&result, &failed, &engine](LuaValue key, LuaValue value) { + auto contKey = engine.luaMaybeTo<typename T::key_type>(move(key)); + auto contValue = engine.luaMaybeTo<typename T::mapped_type>(move(value)); + if (!contKey || !contValue) { + failed = true; + return false; + } + result[contKey.take()] = contValue.take(); + return true; + }); + + if (failed) + return {}; + + return result; + } +}; + +template <typename T> +struct LuaContainerConverter { + static LuaValue from(LuaEngine& engine, T const& v) { + return engine.createArrayTable(v); + } + + static Maybe<T> to(LuaEngine& engine, LuaValue const& v) { + auto table = v.ptr<LuaTable>(); + if (!table) + return {}; + + T result; + bool failed = false; + table->iterate([&result, &failed, &engine](LuaValue key, LuaValue value) { + if (!key.is<LuaInt>()) { + failed = true; + return false; + } + auto contVal = engine.luaMaybeTo<typename T::value_type>(move(value)); + if (!contVal) { + failed = true; + return false; + } + result.insert(result.end(), contVal.take()); + return true; + }); + + if (failed) + return {}; + + return result; + } +}; + +template <typename T, typename Allocator> +struct LuaConverter<List<T, Allocator>> : LuaContainerConverter<List<T, Allocator>> {}; + +template <typename T, size_t MaxSize> +struct LuaConverter<StaticList<T, MaxSize>> : LuaContainerConverter<StaticList<T, MaxSize>> {}; + +template <typename T, size_t MaxStackSize> +struct LuaConverter<SmallList<T, MaxStackSize>> : LuaContainerConverter<SmallList<T, MaxStackSize>> {}; + +template <> +struct LuaConverter<StringList> : LuaContainerConverter<StringList> {}; + +template <typename T, typename BaseSet> +struct LuaConverter<Set<T, BaseSet>> : LuaContainerConverter<Set<T, BaseSet>> {}; + +template <typename T, typename BaseSet> +struct LuaConverter<HashSet<T, BaseSet>> : LuaContainerConverter<HashSet<T, BaseSet>> {}; + +template <typename Key, typename Value, typename Compare, typename Allocator> +struct LuaConverter<Map<Key, Value, Compare, Allocator>> : LuaMapConverter<Map<Key, Value, Compare, Allocator>> {}; + +template <typename Key, typename Value, typename Hash, typename Equals, typename Allocator> +struct LuaConverter<HashMap<Key, Value, Hash, Equals, Allocator>> : LuaMapConverter<HashMap<Key, Value, Hash, Equals, Allocator>> {}; + +template <> +struct LuaConverter<Json> { + static LuaValue from(LuaEngine& engine, Json const& v); + static Maybe<Json> to(LuaEngine& engine, LuaValue const& v); +}; + +template <> +struct LuaConverter<JsonObject> { + static LuaValue from(LuaEngine& engine, JsonObject v); + static Maybe<JsonObject> to(LuaEngine& engine, LuaValue v); +}; + +template <> +struct LuaConverter<JsonArray> { + static LuaValue from(LuaEngine& engine, JsonArray v); + static Maybe<JsonArray> to(LuaEngine& engine, LuaValue v); +}; + +namespace LuaDetail { + inline LuaHandle::LuaHandle(LuaEnginePtr engine, int handleIndex) + : engine(move(engine)), handleIndex(handleIndex) {} + + inline LuaHandle::~LuaHandle() { + if (engine) + engine->destroyHandle(handleIndex); + } + + inline LuaHandle::LuaHandle(LuaHandle const& other) { + engine = other.engine; + if (engine) + handleIndex = engine->copyHandle(other.handleIndex); + } + + inline LuaHandle::LuaHandle(LuaHandle&& other) { + engine = take(other.engine); + handleIndex = take(other.handleIndex); + } + + inline LuaHandle& LuaHandle::operator=(LuaHandle const& other) { + if (engine) + engine->destroyHandle(handleIndex); + + engine = other.engine; + if (engine) + handleIndex = engine->copyHandle(other.handleIndex); + + return *this; + } + + inline LuaHandle& LuaHandle::operator=(LuaHandle&& other) { + if (engine) + engine->destroyHandle(handleIndex); + + engine = take(other.engine); + handleIndex = take(other.handleIndex); + + return *this; + } + + template <typename T> + struct FromFunctionReturn { + static T convert(LuaEngine& engine, LuaFunctionReturn const& ret) { + if (auto l = ret.ptr<LuaValue>()) { + return engine.luaTo<T>(*l); + } else if (auto vec = ret.ptr<LuaVariadic<LuaValue>>()) { + return engine.luaTo<T>(vec->at(0)); + } else { + return engine.luaTo<T>(LuaNil); + } + } + }; + + template <typename T> + struct FromFunctionReturn<LuaVariadic<T>> { + static LuaVariadic<T> convert(LuaEngine& engine, LuaFunctionReturn const& ret) { + if (auto l = ret.ptr<LuaValue>()) { + return {engine.luaTo<T>(*l)}; + } else if (auto vec = ret.ptr<LuaVariadic<LuaValue>>()) { + LuaVariadic<T> ret(vec->size()); + for (size_t i = 0; i < vec->size(); ++i) + ret[i] = engine.luaTo<T>((*vec)[i]); + return ret; + } else { + return {}; + } + } + }; + + template <typename ArgFirst, typename... ArgRest> + struct FromFunctionReturn<LuaTupleReturn<ArgFirst, ArgRest...>> { + static LuaTupleReturn<ArgFirst, ArgRest...> convert(LuaEngine& engine, LuaFunctionReturn const& ret) { + if (auto l = ret.ptr<LuaValue>()) { + return doConvertSingle(engine, *l, typename GenIndexSequence<0, sizeof...(ArgRest)>::type()); + } else if (auto vec = ret.ptr<LuaVariadic<LuaValue>>()) { + return doConvertMulti(engine, *vec, typename GenIndexSequence<0, sizeof...(ArgRest)>::type()); + } else { + return doConvertNone(engine, typename GenIndexSequence<0, sizeof...(ArgRest)>::type()); + } + } + + template <size_t... Indexes> + static LuaTupleReturn<ArgFirst, ArgRest...> doConvertSingle( + LuaEngine& engine, LuaValue const& single, IndexSequence<Indexes...> const&) { + return LuaTupleReturn<ArgFirst, ArgRest...>(engine.luaTo<ArgFirst>(single), engine.luaTo<ArgRest>(LuaNil)...); + } + + template <size_t... Indexes> + static LuaTupleReturn<ArgFirst, ArgRest...> doConvertMulti( + LuaEngine& engine, LuaVariadic<LuaValue> const& multi, IndexSequence<Indexes...> const&) { + return LuaTupleReturn<ArgFirst, ArgRest...>( + engine.luaTo<ArgFirst>(multi.at(0)), engine.luaTo<ArgRest>(multi.get(Indexes + 1))...); + } + + template <size_t... Indexes> + static LuaTupleReturn<ArgFirst, ArgRest...> doConvertNone(LuaEngine& engine, IndexSequence<Indexes...> const&) { + return LuaTupleReturn<ArgFirst, ArgRest...>(engine.luaTo<ArgFirst>(LuaNil), engine.luaTo<ArgRest>(LuaNil)...); + } + }; + + template <typename... Args, size_t... Indexes> + LuaVariadic<LuaValue> toVariadicReturn( + LuaEngine& engine, LuaTupleReturn<Args...> const& vals, IndexSequence<Indexes...> const&) { + return LuaVariadic<LuaValue>{engine.luaFrom(get<Indexes>(vals))...}; + } + + template <typename... Args> + LuaVariadic<LuaValue> toWrappedReturn(LuaEngine& engine, LuaTupleReturn<Args...> const& vals) { + return toVariadicReturn(engine, vals, typename GenIndexSequence<0, sizeof...(Args)>::type()); + } + + template <typename T> + LuaVariadic<LuaValue> toWrappedReturn(LuaEngine& engine, LuaVariadic<T> const& vals) { + LuaVariadic<LuaValue> ret(vals.size()); + for (size_t i = 0; i < vals.size(); ++i) + ret[i] = engine.luaFrom(vals[i]); + return ret; + } + + template <typename T> + LuaValue toWrappedReturn(LuaEngine& engine, T const& t) { + return engine.luaFrom(t); + } + + template <typename T> + struct ArgGet { + static T get(LuaEngine& engine, size_t argc, LuaValue* argv, size_t index) { + if (index < argc) + return engine.luaTo<T>(move(argv[index])); + return engine.luaTo<T>(LuaNil); + } + }; + + template <typename T> + struct ArgGet<LuaVariadic<T>> { + static LuaVariadic<T> get(LuaEngine& engine, size_t argc, LuaValue* argv, size_t index) { + if (index >= argc) + return {}; + + LuaVariadic<T> subargs(argc - index); + for (size_t i = index; i < argc; ++i) + subargs[i - index] = engine.luaTo<T>(move(argv[i])); + return subargs; + } + }; + + template <typename Return, typename... Args> + struct FunctionWrapper { + template <typename Function, size_t... Indexes> + static LuaWrappedFunction wrapIndexes(Function func, IndexSequence<Indexes...> const&) { + return [func = move(func)](LuaEngine& engine, size_t argc, LuaValue* argv) { + return toWrappedReturn(engine, (Return const&)func(ArgGet<Args>::get(engine, argc, argv, Indexes)...)); + }; + } + + template <typename Function> + static LuaWrappedFunction wrap(Function func) { + return wrapIndexes(forward<Function>(func), typename GenIndexSequence<0, sizeof...(Args)>::type()); + } + }; + + template <typename... Args> + struct FunctionWrapper<void, Args...> { + template <typename Function, size_t... Indexes> + static LuaWrappedFunction wrapIndexes(Function func, IndexSequence<Indexes...> const&) { + return [func = move(func)](LuaEngine& engine, size_t argc, LuaValue* argv) { + func(ArgGet<Args>::get(engine, argc, argv, Indexes)...); + return LuaFunctionReturn(); + }; + } + + template <typename Function> + static LuaWrappedFunction wrap(Function func) { + return wrapIndexes(forward<Function>(func), typename GenIndexSequence<0, sizeof...(Args)>::type()); + } + }; + + template <typename Return, typename... Args> + struct FunctionWrapper<Return, LuaEngine, Args...> { + template <typename Function, size_t... Indexes> + static LuaWrappedFunction wrapIndexes(Function func, IndexSequence<Indexes...> const&) { + return [func = move(func)](LuaEngine& engine, size_t argc, LuaValue* argv) { + return toWrappedReturn(engine, (Return const&)func(engine, ArgGet<Args>::get(engine, argc, argv, Indexes)...)); + }; + } + + template <typename Function> + static LuaWrappedFunction wrap(Function func) { + return wrapIndexes(forward<Function>(func), typename GenIndexSequence<0, sizeof...(Args)>::type()); + } + }; + + template <typename... Args> + struct FunctionWrapper<void, LuaEngine, Args...> { + template <typename Function, size_t... Indexes> + static LuaWrappedFunction wrapIndexes(Function func, IndexSequence<Indexes...> const&) { + return [func = move(func)](LuaEngine& engine, size_t argc, LuaValue* argv) { + func(engine, ArgGet<Args>::get(engine, argc, argv, Indexes)...); + return LuaFunctionReturn(); + }; + } + + template <typename Function> + static LuaWrappedFunction wrap(Function func) { + return wrapIndexes(forward<Function>(func), typename GenIndexSequence<0, sizeof...(Args)>::type()); + } + }; + + template <typename Return, typename... Args, typename Function> + LuaWrappedFunction wrapFunctionWithSignature(Function&& func) { + return FunctionWrapper<Return, typename std::decay<Args>::type...>::wrap(forward<Function>(func)); + } + + template <typename Return, typename Function, typename... Args> + LuaWrappedFunction wrapFunctionArgs(Function&& func, VariadicTypedef<Args...> const&) { + return wrapFunctionWithSignature<Return, Args...>(forward<Function>(func)); + } + + template <typename Function> + LuaWrappedFunction wrapFunction(Function&& func) { + return wrapFunctionArgs<typename FunctionTraits<Function>::Return>( + forward<Function>(func), typename FunctionTraits<Function>::Args()); + } + + template <typename Return, typename T, typename... Args> + struct MethodWrapper { + template <typename Function, size_t... Indexes> + static LuaWrappedFunction wrapIndexes(Function func, IndexSequence<Indexes...> const&) { + return [func = move(func)](LuaEngine& engine, size_t argc, LuaValue* argv) mutable { + if (argc == 0) + throw LuaException("No object argument passed to wrapped method"); + return toWrappedReturn(engine, + (Return const&)func(argv[0].get<LuaUserData>().get<T>(), ArgGet<Args>::get(engine, argc - 1, argv + 1, Indexes)...)); + }; + } + + template <typename Function> + static LuaWrappedFunction wrap(Function&& func) { + return wrapIndexes(forward<Function>(func), typename GenIndexSequence<0, sizeof...(Args)>::type()); + } + }; + + template <typename T, typename... Args> + struct MethodWrapper<void, T, Args...> { + template <typename Function, size_t... Indexes> + static LuaWrappedFunction wrapIndexes(Function func, IndexSequence<Indexes...> const&) { + return [func = move(func)](LuaEngine& engine, size_t argc, LuaValue* argv) { + if (argc == 0) + throw LuaException("No object argument passed to wrapped method"); + func(argv[0].get<LuaUserData>().get<T>(), ArgGet<Args>::get(engine, argc - 1, argv + 1, Indexes)...); + return LuaFunctionReturn(); + }; + } + + template <typename Function> + static LuaWrappedFunction wrap(Function func) { + return wrapIndexes(forward<Function>(func), typename GenIndexSequence<0, sizeof...(Args)>::type()); + } + }; + + template <typename Return, typename T, typename... Args> + struct MethodWrapper<Return, T, LuaEngine, Args...> { + template <typename Function, size_t... Indexes> + static LuaWrappedFunction wrapIndexes(Function func, IndexSequence<Indexes...> const&) { + return [func = move(func)](LuaEngine& engine, size_t argc, LuaValue* argv) { + if (argc == 0) + throw LuaException("No object argument passed to wrapped method"); + return toWrappedReturn( + engine, + (Return const&)func(argv[0].get<LuaUserData>().get<T>(), engine, ArgGet<Args>::get(engine, argc - 1, argv + 1, Indexes)...)); + }; + } + + template <typename Function> + static LuaWrappedFunction wrap(Function func) { + return wrapIndexes(forward<Function>(func), typename GenIndexSequence<0, sizeof...(Args)>::type()); + } + }; + + template <typename T, typename... Args> + struct MethodWrapper<void, T, LuaEngine, Args...> { + template <typename Function, size_t... Indexes> + static LuaWrappedFunction wrapIndexes(Function func, IndexSequence<Indexes...> const&) { + return [func = move(func)](LuaEngine& engine, size_t argc, LuaValue* argv) { + if (argc == 0) + throw LuaException("No object argument passed to wrapped method"); + func(argv[0].get<LuaUserData>().get<T>(), engine, ArgGet<Args>::get(engine, argc - 1, argv + 1, Indexes)...); + return LuaValue(); + }; + } + + template <typename Function> + static LuaWrappedFunction wrap(Function func) { + return wrapIndexes(forward<Function>(func), typename GenIndexSequence<0, sizeof...(Args)>::type()); + } + }; + + template <typename Return, typename... Args, typename Function> + LuaWrappedFunction wrapMethodWithSignature(Function&& func) { + return MethodWrapper<Return, typename std::decay<Args>::type...>::wrap(forward<Function>(func)); + } + + template <typename Return, typename Function, typename... Args> + LuaWrappedFunction wrapMethodArgs(Function&& func, VariadicTypedef<Args...> const&) { + return wrapMethodWithSignature<Return, Args...>(forward<Function>(func)); + } + + template <typename Function> + LuaWrappedFunction wrapMethod(Function&& func) { + return wrapMethodArgs<typename FunctionTraits<Function>::Return>( + forward<Function>(func), typename FunctionTraits<Function>::Args()); + } + + template <typename Ret, typename... Args> + struct TableIteratorWrapper; + + template <typename Key, typename Value> + struct TableIteratorWrapper<bool, LuaEngine&, Key, Value> { + template <typename Function> + static function<bool(LuaValue, LuaValue)> wrap(LuaEngine& engine, Function&& func) { + return [&engine, func = move(func)](LuaValue key, LuaValue value) -> bool { + return func(engine, engine.luaTo<Key>(move(key)), engine.luaTo<Value>(move(value))); + }; + } + }; + + template <typename Key, typename Value> + struct TableIteratorWrapper<void, LuaEngine&, Key, Value> { + template <typename Function> + static function<bool(LuaValue, LuaValue)> wrap(LuaEngine& engine, Function&& func) { + return [&engine, func = move(func)](LuaValue key, LuaValue value) -> bool { + func(engine, engine.luaTo<Key>(move(key)), engine.luaTo<Value>(move(value))); + return true; + }; + } + }; + + template <typename Key, typename Value> + struct TableIteratorWrapper<bool, Key, Value> { + template <typename Function> + static function<bool(LuaValue, LuaValue)> wrap(LuaEngine& engine, Function&& func) { + return [&engine, func = move(func)](LuaValue key, LuaValue value) -> bool { + return func(engine.luaTo<Key>(move(key)), engine.luaTo<Value>(move(value))); + }; + } + }; + + template <typename Key, typename Value> + struct TableIteratorWrapper<void, Key, Value> { + template <typename Function> + static function<bool(LuaValue, LuaValue)> wrap(LuaEngine& engine, Function&& func) { + return [&engine, func = move(func)](LuaValue key, LuaValue value) -> bool { + func(engine.luaTo<Key>(move(key)), engine.luaTo<Value>(move(value))); + return true; + }; + } + }; + + template <typename Return, typename... Args, typename Function> + function<bool(LuaValue, LuaValue)> wrapTableIteratorWithSignature(LuaEngine& engine, Function&& func) { + return TableIteratorWrapper<Return, typename std::decay<Args>::type...>::wrap(engine, forward<Function>(func)); + } + + template <typename Return, typename Function, typename... Args> + function<bool(LuaValue, LuaValue)> wrapTableIteratorArgs( + LuaEngine& engine, Function&& func, VariadicTypedef<Args...> const&) { + return wrapTableIteratorWithSignature<Return, Args...>(engine, forward<Function>(func)); + } + + template <typename Function> + function<bool(LuaValue, LuaValue)> wrapTableIterator(LuaEngine& engine, Function&& func) { + return wrapTableIteratorArgs<typename FunctionTraits<Function>::Return>( + engine, forward<Function>(func), typename FunctionTraits<Function>::Args()); + } + + // Like lua_setfield / lua_getfield but raw. + void rawSetField(lua_State* state, int index, char const* key); + void rawGetField(lua_State* state, int index, char const* key); + + // Shallow copies a lua table at the given index into the table at the target + // index. + void shallowCopy(lua_State* state, int sourceIndex, int targetIndex); + + // Creates a custom lua table from a JsonArray or JsonObject that has + // slightly different behavior than a standard lua table. The table + // remembers nil entries, as well as whether it was initially constructed + // from a JsonArray or JsonObject as a hint on how to convert it back into a + // Json. The custom containers are meant to act nearly identical to standard + // lua tables, so iterating over the table with pairs or ipairs works exactly + // like a standard lua table, so will skip over nil entries and in the case + // of ipairs, stop at the first nil entry. + LuaTable jsonContainerToTable(LuaEngine& engine, Json const& container); + + // popJsonContainer must be called with a lua table on the top of the stack. + // Uses the table contents, as well as any hint entries if the table was + // created originally from a Json, to determine whether a JsonArray or + // JsonObject is more appropriate. + Maybe<Json> tableToJsonContainer(LuaTable const& t); + + // Special lua functions to operate on our custom jarray / jobject container + // types. Should always do some "sensible" action if given a regular lua + // table instead of a custom json container one. + + // Create a JsonList container table + Json jarrayCreate(); + + // Create a JsonMap container table + Json jobjectCreate(); + + // *Really* remove an entry from a JsonList or JsonMap container table, + // including removing it from the __nils table. If the given table is not a + // special container table, is equivalent to setting the key entry to nil. + void jcontRemove(LuaTable const& t, LuaValue const& key); + + // Returns the element count of the lua table argument, or, in the case of a + // special JsonList container table, returns the "true" element count + // including any nil entries. + size_t jcontSize(LuaTable const& t); + + // Resize the given lua table by removing any indexed entries greater than the + // target size, and in the case of a special JsonList container table, pads + // to the end of the new size with nil entries. + void jcontResize(LuaTable const& t, size_t size); + + // Coerces a values (strings, floats, ints) into an integer, but fails if the + // number looks fractional (does not parse as int, float is not an exact + // integer) + Maybe<LuaInt> asInteger(LuaValue const& v); +} + +template <typename Container> +LuaVariadic<typename std::decay<Container>::type::value_type> luaUnpack(Container&& c) { + LuaVariadic<typename std::decay<Container>::type::value_type> ret; + if (std::is_rvalue_reference<Container&&>::value) { + for (auto& e : c) + ret.append(move(e)); + } else { + for (auto const& e : c) + ret.append(e); + } + return ret; +} + +template <typename... Types> +LuaTupleReturn<Types...>::LuaTupleReturn(Types const&... args) + : Base(args...) {} + +template <typename... Types> +template <typename... UTypes> +LuaTupleReturn<Types...>::LuaTupleReturn(UTypes&&... args) + : Base(move(args)...) {} + +template <typename... Types> +template <typename... UTypes> +LuaTupleReturn<Types...>::LuaTupleReturn(UTypes const&... args) + : Base(args...) {} + +template <typename... Types> +LuaTupleReturn<Types...>::LuaTupleReturn(LuaTupleReturn const& rhs) + : Base(rhs) {} + +template <typename... Types> +LuaTupleReturn<Types...>::LuaTupleReturn(LuaTupleReturn&& rhs) + : Base(move(rhs)) {} + +template <typename... Types> +template <typename... UTypes> +LuaTupleReturn<Types...>::LuaTupleReturn(LuaTupleReturn<UTypes...> const& rhs) + : Base(rhs) {} + +template <typename... Types> +template <typename... UTypes> +LuaTupleReturn<Types...>::LuaTupleReturn(LuaTupleReturn<UTypes...>&& rhs) + : Base(move(rhs)) {} + +template <typename... Types> +LuaTupleReturn<Types...>& LuaTupleReturn<Types...>::operator=(LuaTupleReturn const& rhs) { + Base::operator=(rhs); + return *this; +} + +template <typename... Types> +LuaTupleReturn<Types...>& LuaTupleReturn<Types...>::operator=(LuaTupleReturn&& rhs) { + Base::operator=(move(rhs)); + return *this; +} + +template <typename... Types> +template <typename... UTypes> +LuaTupleReturn<Types...>& LuaTupleReturn<Types...>::operator=(LuaTupleReturn<UTypes...> const& rhs) { + Base::operator=((tuple<UTypes...> const&)rhs); + return *this; +} + +template <typename... Types> +template <typename... UTypes> +LuaTupleReturn<Types...>& LuaTupleReturn<Types...>::operator=(LuaTupleReturn<UTypes...>&& rhs) { + Base::operator=((tuple<UTypes...> && )move(rhs)); + return *this; +} + +template <typename... Types> +LuaTupleReturn<Types&...> luaTie(Types&... args) { + return LuaTupleReturn<Types&...>(args...); +} + +template <typename... Types> +LuaTupleReturn<typename std::decay<Types>::type...> luaTupleReturn(Types&&... args) { + return LuaTupleReturn<typename std::decay<Types>::type...>(forward<Types>(args)...); +} + +inline LuaReference::LuaReference(LuaDetail::LuaHandle handle) : m_handle(move(handle)) {} + +inline bool LuaReference::operator==(LuaReference const& rhs) const { + return tie(m_handle.engine, m_handle.handleIndex) == tie(rhs.m_handle.engine, rhs.m_handle.handleIndex); +} + +inline bool LuaReference::operator!=(LuaReference const& rhs) const { + return tie(m_handle.engine, m_handle.handleIndex) != tie(rhs.m_handle.engine, rhs.m_handle.handleIndex); +} + +inline LuaEngine& LuaReference::engine() const { + return *m_handle.engine; +} + +inline int LuaReference::handleIndex() const { + return m_handle.handleIndex; +} + +inline char const* LuaString::ptr() const { + return engine().stringPtr(handleIndex()); +} + +inline size_t LuaString::length() const { + return engine().stringLength(handleIndex()); +} + +inline String LuaString::toString() const { + return String(ptr()); +} + +inline bool operator==(LuaString const& s1, LuaString const& s2) { + return std::strcmp(s1.ptr(), s2.ptr()) == 0; +} + +inline bool operator==(LuaString const& s1, char const* s2) { + return std::strcmp(s1.ptr(), s2) == 0; +} + +inline bool operator==(LuaString const& s1, std::string const& s2) { + return s1.ptr() == s2; +} + +inline bool operator==(LuaString const& s1, String const& s2) { + return s1.ptr() == s2; +} + +inline bool operator==(char const* s1, LuaString const& s2) { + return std::strcmp(s1, s2.ptr()) == 0; +} + +inline bool operator==(std::string const& s1, LuaString const& s2) { + return s1 == s2.ptr(); +} + +inline bool operator==(String const& s1, LuaString const& s2) { + return s1 == s2.ptr(); +} + +inline bool operator!=(LuaString const& s1, LuaString const& s2) { + return !(s1 == s2); +} + +inline bool operator!=(LuaString const& s1, char const* s2) { + return !(s1 == s2); +} + +inline bool operator!=(LuaString const& s1, std::string const& s2) { + return !(s1 == s2); +} + +inline bool operator!=(LuaString const& s1, String const& s2) { + return !(s1 == s2); +} + +inline bool operator!=(char const* s1, LuaString const& s2) { + return !(s1 == s2); +} + +inline bool operator!=(std::string const& s1, LuaString const& s2) { + return !(s1 == s2); +} + +inline bool operator!=(String const& s1, LuaString const& s2) { + return !(s1 == s2); +} + +template <typename T, typename K> +T LuaTable::get(K key) const { + return engine().luaTo<T>(engine().tableGet(false, handleIndex(), engine().luaFrom(move(key)))); +} + +template <typename T> +T LuaTable::get(char const* key) const { + return engine().luaTo<T>(engine().tableGet(false, handleIndex(), key)); +} + +template <typename T, typename K> +void LuaTable::set(K key, T value) const { + engine().tableSet(false, handleIndex(), engine().luaFrom(move(key)), engine().luaFrom(move(value))); +} + +template <typename T> +void LuaTable::set(char const* key, T value) const { + engine().tableSet(false, handleIndex(), key, engine().luaFrom(move(value))); +} + +template <typename K> +bool LuaTable::contains(K key) const { + return engine().tableGet(false, handleIndex(), engine().luaFrom(move(key))) != LuaNil; +} + +template <typename K> +void LuaTable::remove(K key) const { + engine().tableSet(false, handleIndex(), engine().luaFrom(key), LuaValue()); +} + +template <typename Function> +void LuaTable::iterate(Function&& function) const { + return engine().tableIterate(handleIndex(), LuaDetail::wrapTableIterator(engine(), forward<Function>(function))); +} + +template <typename Return, typename... Args, typename Function> +void LuaTable::iterateWithSignature(Function&& func) const { + return engine().tableIterate(handleIndex(), LuaDetail::wrapTableIteratorWithSignature<Return, Args...>(engine(), forward<Function>(func))); +} + +template <typename T, typename K> +T LuaTable::rawGet(K key) const { + return engine().luaTo<T>(engine().tableGet(true, handleIndex(), engine().luaFrom(key))); +} + +template <typename T> +T LuaTable::rawGet(char const* key) const { + return engine().luaTo<T>(engine().tableGet(true, handleIndex(), key)); +} + +template <typename T, typename K> +void LuaTable::rawSet(K key, T value) const { + engine().tableSet(true, handleIndex(), engine().luaFrom(key), engine().luaFrom(value)); +} + +template <typename T> +void LuaTable::rawSet(char const* key, T value) const { + engine().tableSet(true, handleIndex(), engine().luaFrom(key), engine().luaFrom(value)); +} + +template <typename Ret, typename... Args> +Ret LuaFunction::invoke(Args const&... args) const { + return LuaDetail::FromFunctionReturn<Ret>::convert(engine(), engine().callFunction(handleIndex(), args...)); +} + +template <typename Ret, typename... Args> +Maybe<Ret> LuaThread::resume(Args const&... args) const { + auto res = engine().resumeThread(handleIndex(), args...); + if (!res) + return {}; + return LuaDetail::FromFunctionReturn<Ret>::convert(engine(), res.take()); +} + +inline void LuaThread::pushFunction(LuaFunction const& func) const { + engine().threadPushFunction(handleIndex(), func.handleIndex()); +} + +inline LuaThread::Status LuaThread::status() const { + return engine().threadStatus(handleIndex()); +} + +template <typename T> +bool LuaUserData::is() const { + return engine().userDataIsType<T>(handleIndex()); +} + +template <typename T> +T& LuaUserData::get() const { + return *engine().getUserData<T>(handleIndex()); +} + +template <typename Function> +void LuaCallbacks::registerCallback(String name, Function&& func) { + if (!m_callbacks.insert(name, LuaDetail::wrapFunction(forward<Function>(func))).second) + throw LuaException::format("Lua callback '%s' was registered twice", name); +} + +template <typename Return, typename... Args, typename Function> +void LuaCallbacks::registerCallbackWithSignature(String name, Function&& func) { + if (!m_callbacks.insert(name, LuaDetail::wrapFunctionWithSignature<Return, Args...>(forward<Function>(func))).second) + throw LuaException::format("Lua callback '%s' was registered twice", name); +} + +template <typename T> +template <typename Function> +void LuaMethods<T>::registerMethod(String name, Function&& func) { + if (!m_methods.insert(name, LuaDetail::wrapMethod(forward<Function>(move(func)))).second) + throw LuaException::format("Lua method '%s' was registered twice", name); +} + +template <typename T> +template <typename Return, typename... Args, typename Function> +void LuaMethods<T>::registerMethodWithSignature(String name, Function&& func) { + if (!m_methods.insert(name, LuaDetail::wrapMethodWithSignature<Return, Args...>(forward<Function>(move(func)))) + .second) + throw LuaException::format("Lua method '%s' was registered twice", name); +} + +template <typename T> +StringMap<LuaDetail::LuaWrappedFunction> const& LuaMethods<T>::methods() const { + return m_methods; +} + +template <typename T> +T LuaContext::getPath(String path) const { + return engine().luaTo<T>(engine().contextGetPath(handleIndex(), move(path))); +} + +template <typename T> +void LuaContext::setPath(String key, T value) { + engine().contextSetPath(handleIndex(), move(key), engine().luaFrom<T>(move(value))); +} + +template <typename Ret> +Ret LuaContext::eval(String const& lua) { + return LuaDetail::FromFunctionReturn<Ret>::convert(engine(), engine().contextEval(handleIndex(), lua)); +} + +template <typename Ret, typename... Args> +Ret LuaContext::invokePath(String const& key, Args const&... args) const { + auto p = getPath(key); + if (auto f = p.ptr<LuaFunction>()) + return f->invoke<Ret>(args...); + throw LuaException::format("invokePath called on path '%s' which is not function type", key); +} + +template <typename T> +LuaValue LuaContext::luaFrom(T&& t) { + return engine().luaFrom(forward<T>(t)); +} + +template <typename T> +LuaValue LuaContext::luaFrom(T const& t) { + return engine().luaFrom(t); +} + +template <typename T> +Maybe<T> LuaContext::luaMaybeTo(LuaValue&& v) { + return engine().luaFrom(move(v)); +} + +template <typename T> +Maybe<T> LuaContext::luaMaybeTo(LuaValue const& v) { + return engine().luaFrom(v); +} + +template <typename T> +T LuaContext::luaTo(LuaValue&& v) { + return engine().luaTo<T>(move(v)); +} + +template <typename T> +T LuaContext::luaTo(LuaValue const& v) { + return engine().luaTo<T>(v); +} + +template <typename Container> +LuaTable LuaContext::createTable(Container const& map) { + return engine().createTable(map); +} + +template <typename Container> +LuaTable LuaContext::createArrayTable(Container const& array) { + return engine().createArrayTable(array); +} + +template <typename Function> +LuaFunction LuaContext::createFunction(Function&& func) { + return engine().createFunction(forward<Function>(func)); +} + +template <typename Return, typename... Args, typename Function> +LuaFunction LuaContext::createFunctionWithSignature(Function&& func) { + return engine().createFunctionWithSignature<Return, Args...>(forward<Function>(func)); +} + +template <typename T> +LuaUserData LuaContext::createUserData(T t) { + return engine().createUserData(move(t)); +} + +template <typename T> +LuaMethods<T> LuaUserDataMethods<T>::make() { + return LuaMethods<T>(); +} + +template <typename T> +LuaValue LuaUserDataConverter<T>::from(LuaEngine& engine, T t) { + return engine.createUserData(move(t)); +} + +template <typename T> +Maybe<T> LuaUserDataConverter<T>::to(LuaEngine&, LuaValue const& v) { + if (auto ud = v.ptr<LuaUserData>()) { + if (ud->is<T>()) + return ud->get<T>(); + } + return {}; +} + +template <typename T> +LuaValue LuaEngine::luaFrom(T&& t) { + return LuaConverter<typename std::decay<T>::type>::from(*this, forward<T>(t)); +} + +template <typename T> +LuaValue LuaEngine::luaFrom(T const& t) { + return LuaConverter<T>::from(*this, t); +} + +template <typename T> +Maybe<T> LuaEngine::luaMaybeTo(LuaValue&& v) { + return LuaConverter<T>::to(*this, move(v)); +} + +template <typename T> +Maybe<T> LuaEngine::luaMaybeTo(LuaValue const& v) { + return LuaConverter<T>::to(*this, v); +} + +template <typename T> +T LuaEngine::luaTo(LuaValue&& v) { + if (auto res = luaMaybeTo<T>(move(v))) + return res.take(); + throw LuaConversionException::format("Error converting LuaValue to type '%s'", typeid(T).name()); +} + +template <typename T> +T LuaEngine::luaTo(LuaValue const& v) { + if (auto res = luaMaybeTo<T>(v)) + return res.take(); + throw LuaConversionException::format("Error converting LuaValue to type '%s'", typeid(T).name()); +} + +template <typename Container> +LuaTable LuaEngine::createTable(Container const& map) { + auto table = createTable(); + for (auto const& p : map) + table.set(p.first, p.second); + return table; +} + +template <typename Container> +LuaTable LuaEngine::createArrayTable(Container const& array) { + auto table = createTable(); + int i = 1; + for (auto const& elem : array) { + table.set(LuaInt(i), elem); + ++i; + } + return table; +} + +template <typename Function> +LuaFunction LuaEngine::createFunction(Function&& func) { + return createWrappedFunction(LuaDetail::wrapFunction(forward<Function>(func))); +} + +template <typename Return, typename... Args, typename Function> +LuaFunction LuaEngine::createFunctionWithSignature(Function&& func) { + return createWrappedFunction(LuaDetail::wrapFunctionWithSignature<Return, Args...>(forward<Function>(func))); +} + +template <typename... Args> +LuaDetail::LuaFunctionReturn LuaEngine::callFunction(int handleIndex, Args const&... args) { + lua_checkstack(m_state, 1); + + int stackSize = lua_gettop(m_state); + pushHandle(m_state, handleIndex); + + size_t argSize = pushArguments(m_state, args...); + + incrementRecursionLevel(); + int res = pcallWithTraceback(m_state, argSize, LUA_MULTRET); + decrementRecursionLevel(); + handleError(m_state, res); + + int returnValues = lua_gettop(m_state) - stackSize; + if (returnValues == 0) { + return LuaDetail::LuaFunctionReturn(); + } else if (returnValues == 1) { + return popLuaValue(m_state); + } else { + LuaVariadic<LuaValue> ret(returnValues); + for (int i = returnValues - 1; i >= 0; --i) + ret[i] = popLuaValue(m_state); + return ret; + } +} + +template <typename... Args> +Maybe<LuaDetail::LuaFunctionReturn> LuaEngine::resumeThread(int handleIndex, Args const&... args) { + lua_checkstack(m_state, 1); + + pushHandle(m_state, handleIndex); + lua_State* threadState = lua_tothread(m_state, -1); + lua_pop(m_state, 1); + + if (lua_status(threadState) != LUA_YIELD && lua_gettop(threadState) == 0) { + throw LuaException("cannot resume a dead or errored thread"); + } + + size_t argSize = pushArguments(threadState, args...); + incrementRecursionLevel(); + int res = lua_resume(threadState, nullptr, argSize); + decrementRecursionLevel(); + if (res != LUA_OK && res != LUA_YIELD) { + propagateErrorWithTraceback(threadState, m_state); + handleError(m_state, res); + } + + int returnValues = lua_gettop(threadState); + if (returnValues == 0) { + return LuaDetail::LuaFunctionReturn(); + } else if (returnValues == 1) { + return LuaDetail::LuaFunctionReturn(popLuaValue(threadState)); + } else { + LuaVariadic<LuaValue> ret(returnValues); + for (int i = returnValues - 1; i >= 0; --i) + ret[i] = popLuaValue(threadState); + return LuaDetail::LuaFunctionReturn(move(ret)); + } +} + +template <typename T> +void LuaEngine::registerUserDataType() { + if (m_registeredUserDataTypes.contains(typeid(T))) + return; + + lua_checkstack(m_state, 2); + + lua_newtable(m_state); + + // Set the __index on the metatable to itself + lua_pushvalue(m_state, -1); + LuaDetail::rawSetField(m_state, -2, "__index"); + lua_pushboolean(m_state, 0); + LuaDetail::rawSetField(m_state, -2, "__metatable"); // protect metatable + + // Set the __gc function to the userdata destructor + auto gcFunction = [](lua_State* state) { + T& t = *(T*)(lua_touserdata(state, 1)); + t.~T(); + return 0; + }; + lua_pushcfunction(m_state, gcFunction); + LuaDetail::rawSetField(m_state, -2, "__gc"); + + auto methods = LuaUserDataMethods<T>::make(); + for (auto& p : methods.methods()) { + pushLuaValue(m_state, createWrappedFunction(p.second)); + LuaDetail::rawSetField(m_state, -2, p.first.utf8Ptr()); + } + + m_registeredUserDataTypes.add(typeid(T), luaL_ref(m_state, LUA_REGISTRYINDEX)); +} + +template <typename T> +LuaUserData LuaEngine::createUserData(T t) { + registerUserDataType<T>(); + + int typeMetatable = m_registeredUserDataTypes.get(typeid(T)); + + lua_checkstack(m_state, 2); + + new (lua_newuserdata(m_state, sizeof(T))) T(move(t)); + + lua_rawgeti(m_state, LUA_REGISTRYINDEX, typeMetatable); + lua_setmetatable(m_state, -2); + + return LuaUserData(LuaDetail::LuaHandle(RefPtr<LuaEngine>(this), popHandle(m_state))); +} + +template <typename T, typename K> +T LuaEngine::getGlobal(K key) { + lua_checkstack(m_state, 1); + lua_rawgeti(m_state, LUA_REGISTRYINDEX, m_scriptDefaultEnvRegistryId); + pushLuaValue(m_state, luaFrom(move(key))); + lua_rawget(m_state, -2); + + LuaValue v = popLuaValue(m_state); + lua_pop(m_state, 1); + + return luaTo<T>(v); +} + +template <typename T> +T LuaEngine::getGlobal(char const* key) { + lua_checkstack(m_state, 1); + lua_rawgeti(m_state, LUA_REGISTRYINDEX, m_scriptDefaultEnvRegistryId); + lua_getfield(m_state, -1, key); + + LuaValue v = popLuaValue(m_state); + lua_pop(m_state, 1); + + return luaTo<T>(v); +} + +template <typename T, typename K> +void LuaEngine::setGlobal(K key, T value) { + lua_checkstack(m_state, 1); + + lua_rawgeti(m_state, LUA_REGISTRYINDEX, m_scriptDefaultEnvRegistryId); + pushLuaValue(m_state, luaFrom(move(key))); + pushLuaValue(m_state, luaFrom(move(value))); + + lua_rawset(m_state, -3); + lua_pop(m_state, 1); +} + +template <typename T> +void LuaEngine::setGlobal(char const* key, T value) { + lua_checkstack(m_state, 1); + + lua_rawgeti(m_state, LUA_REGISTRYINDEX, m_scriptDefaultEnvRegistryId); + pushLuaValue(m_state, value); + + lua_setfield(m_state, -2, key); + lua_pop(m_state, 1); +} + +template <typename T> +bool LuaEngine::userDataIsType(int handleIndex) { + int typeRef = m_registeredUserDataTypes.value(typeid(T), LUA_NOREF); + if (typeRef == LUA_NOREF) + return false; + + lua_checkstack(m_state, 3); + + pushHandle(m_state, handleIndex); + if (lua_getmetatable(m_state, -1) == 0) { + lua_pop(m_state, 1); + throw LuaException("Userdata missing metatable in userDataIsType"); + } + + lua_rawgeti(m_state, LUA_REGISTRYINDEX, typeRef); + bool typesEqual = lua_rawequal(m_state, -1, -2); + lua_pop(m_state, 3); + + return typesEqual; +} + +template <typename T> +T* LuaEngine::getUserData(int handleIndex) { + int typeRef = m_registeredUserDataTypes.value(typeid(T), LUA_NOREF); + if (typeRef == LUA_NOREF) + throw LuaException::format("Cannot convert userdata type of %s, not registered", typeid(T).name()); + + lua_checkstack(m_state, 3); + + pushHandle(m_state, handleIndex); + T* userdata = (T*)lua_touserdata(m_state, -1); + if (lua_getmetatable(m_state, -1) == 0) { + lua_pop(m_state, 1); + throw LuaException("Cannot get userdata from lua type, no metatable found"); + } + + lua_rawgeti(m_state, LUA_REGISTRYINDEX, typeRef); + if (!lua_rawequal(m_state, -1, -2)) { + lua_pop(m_state, 3); + throw LuaException::format("Improper conversion from userdata to type %s", typeid(T).name()); + } + + lua_pop(m_state, 3); + + return userdata; +} + +inline void LuaEngine::destroyHandle(int handleIndex) { + // We don't bother setting the entry in the handle stack to nil, we just wait + // for it to be reused and overwritten. We could provide a way to + // periodically ensure that the entire free list is niled out if this becomes + // a memory issue? + m_handleFree.append(handleIndex); +} + +template <typename T> +size_t LuaEngine::pushArgument(lua_State* state, T const& arg) { + pushLuaValue(state, luaFrom(arg)); + return 1; +} + +template <typename T> +size_t LuaEngine::pushArgument(lua_State* state, LuaVariadic<T> const& args) { + // If the argument list was empty then we've checked one extra space on the + // stack, oh well. + if (args.empty()) + return 0; + + // top-level pushArguments does a stack check of the total size of the + // argument list, for variadic arguments, it could be more than one + // argument so check the stack for the arguments in the variadic list minus + // one. + lua_checkstack(state, args.size() - 1); + for (size_t i = 0; i < args.size(); ++i) + pushLuaValue(state, luaFrom(args[i])); + return args.size(); +} + +inline size_t LuaEngine::doPushArguments(lua_State*) { + return 0; +} + +template <typename First, typename... Rest> +size_t LuaEngine::doPushArguments(lua_State* state, First const& first, Rest const&... rest) { + size_t s = pushArgument(state, first); + return s + doPushArguments(state, rest...); +} + +template <typename... Args> +size_t LuaEngine::pushArguments(lua_State* state, Args const&... args) { + lua_checkstack(state, sizeof...(args)); + return doPushArguments(state, args...); +} + +} + +#endif diff --git a/source/core/StarLuaConverters.cpp b/source/core/StarLuaConverters.cpp new file mode 100644 index 0000000..e09876a --- /dev/null +++ b/source/core/StarLuaConverters.cpp @@ -0,0 +1,42 @@ +#include "StarLuaConverters.hpp" +#include "StarColor.hpp" + +namespace Star { + +LuaValue LuaConverter<Color>::from(LuaEngine& engine, Color const& c) { + if (c.alpha() == 255) + return engine.createArrayTable(initializer_list<uint8_t>{c.red(), c.green(), c.blue()}); + else + return engine.createArrayTable(initializer_list<uint8_t>{c.red(), c.green(), c.blue(), c.alpha()}); +} + +Maybe<Color> LuaConverter<Color>::to(LuaEngine& engine, LuaValue const& v) { + if (auto t = v.ptr<LuaTable>()) { + Color c = Color::rgba(0, 0, 0, 255); + Maybe<int> r = engine.luaMaybeTo<int>(t->get(1)); + Maybe<int> g = engine.luaMaybeTo<int>(t->get(2)); + Maybe<int> b = engine.luaMaybeTo<int>(t->get(3)); + if (!r || !g || !b) + return {}; + + c.setRed(*r); + c.setGreen(*g); + c.setBlue(*b); + + if (Maybe<int> a = engine.luaMaybeTo<int>(t->get(4))) { + if (!a) + return {}; + c.setAlpha(*a); + } + + return c; + } else if (auto s = v.ptr<LuaString>()) { + try { + return Color(s->ptr()); + } catch (ColorException) {} + } + + return {}; +} + +} diff --git a/source/core/StarLuaConverters.hpp b/source/core/StarLuaConverters.hpp new file mode 100644 index 0000000..bf030cd --- /dev/null +++ b/source/core/StarLuaConverters.hpp @@ -0,0 +1,264 @@ +#ifndef STAR_LUA_CONVERTERS_HPP +#define STAR_LUA_CONVERTERS_HPP + +#include "StarRect.hpp" +#include "StarVector.hpp" +#include "StarColor.hpp" +#include "StarPoly.hpp" +#include "StarLine.hpp" +#include "StarLua.hpp" +#include "StarVariant.hpp" + +namespace Star { + +template <typename T1, typename T2> +struct LuaConverter<pair<T1, T2>> { + static LuaValue from(LuaEngine& engine, pair<T1, T2>&& v) { + auto t = engine.createTable(); + t.set(1, move(v.first)); + t.set(2, move(v.second)); + return t; + } + + static LuaValue from(LuaEngine& engine, pair<T1, T2> const& v) { + auto t = engine.createTable(); + t.set(1, v.first); + t.set(2, v.second); + return t; + } + + static Maybe<pair<T1, T2>> to(LuaEngine& engine, LuaValue const& v) { + if (auto table = engine.luaMaybeTo<LuaTable>(move(v))) { + auto p1 = engine.luaMaybeTo<T1>(table->get(1)); + auto p2 = engine.luaMaybeTo<T2>(table->get(2)); + if (p1 && p2) + return {{p1.take(), p2.take()}}; + } + return {}; + } +}; + +template <typename T, size_t N> +struct LuaConverter<Vector<T, N>> { + static LuaValue from(LuaEngine& engine, Vector<T, N> const& v) { + return engine.createArrayTable(v); + } + + static Maybe<Vector<T, N>> to(LuaEngine& engine, LuaValue const& v) { + auto table = v.ptr<LuaTable>(); + if (!table) + return {}; + + Vector<T, N> vec; + for (size_t i = 0; i < N; ++i) { + auto v = engine.luaMaybeTo<T>(table->get(i + 1)); + if (!v) + return {}; + vec[i] = *v; + } + return vec; + } +}; + +template <typename T> +struct LuaConverter<Matrix3<T>> { + static LuaValue from(LuaEngine& engine, Matrix3<T> const& m) { + auto table = engine.createTable(); + table.set(1, luaFrom(engine, m.row1)); + table.set(2, luaFrom(engine, m.row2)); + table.set(3, luaFrom(engine, m.row3)); + return table; + } + + static Maybe<Matrix3<T>> to(LuaEngine& engine, LuaValue const& v) { + if (auto table = v.ptr<LuaTable>()) { + auto r1 = engine.luaMaybeTo<typename Matrix3<T>::Vec3>(table->get(1)); + auto r2 = engine.luaMaybeTo<typename Matrix3<T>::Vec3>(table->get(2)); + auto r3 = engine.luaMaybeTo<typename Matrix3<T>::Vec3>(table->get(3)); + if (r1 && r2 && r3) + return Matrix3<T>(*r1, *r2, *r3); + } + return {}; + } +}; + +template <typename T> +struct LuaConverter<Rect<T>> { + static LuaValue from(LuaEngine& engine, Rect<T> const& v) { + if (v.isNull()) + return LuaNil; + auto t = engine.createTable(); + t.set(1, v.xMin()); + t.set(2, v.yMin()); + t.set(3, v.xMax()); + t.set(4, v.yMax()); + return t; + } + + static Maybe<Rect<T>> to(LuaEngine& engine, LuaValue const& v) { + if (v == LuaNil) + return Rect<T>::null(); + + if (auto table = v.ptr<LuaTable>()) { + auto xMin = engine.luaMaybeTo<T>(table->get(1)); + auto yMin = engine.luaMaybeTo<T>(table->get(2)); + auto xMax = engine.luaMaybeTo<T>(table->get(3)); + auto yMax = engine.luaMaybeTo<T>(table->get(4)); + if (xMin && yMin && xMax && yMax) + return Rect<T>(*xMin, *yMin, *xMax, *yMax); + } + return {}; + } +}; + +template <typename T> +struct LuaConverter<Polygon<T>> { + static LuaValue from(LuaEngine& engine, Polygon<T> const& poly) { + return engine.createArrayTable(poly.vertexes()); + } + + static Maybe<Polygon<T>> to(LuaEngine& engine, LuaValue const& v) { + if (auto points = engine.luaMaybeTo<typename Polygon<T>::VertexList>(v)) + return Polygon<T>(points.take()); + return {}; + } +}; + +template <typename T, size_t N> +struct LuaConverter<Line<T, N>> { + static LuaValue from(LuaEngine& engine, Line<T, N> const& line) { + auto table = engine.createTable(); + table.set(1, line.min()); + table.set(2, line.max()); + return table; + } + + static Maybe<Line<T, N>> to(LuaEngine& engine, LuaValue const& v) { + if (auto table = v.ptr<LuaTable>()) { + auto min = engine.luaMaybeTo<Vector<T, N>>(table->get(1)); + auto max = engine.luaMaybeTo<Vector<T, N>>(table->get(2)); + if (min && max) + return Line<T, N>(*min, *max); + } + return {}; + } +}; + +// Sort of magical converter, tries to convert from all the types in the +// Variant in order, returning the first correct type. Types should not be +// ambiguous, or the more specific types should come first, which relies on the +// implementation of the converters. +template <typename FirstType, typename... RestTypes> +struct LuaConverter<Variant<FirstType, RestTypes...>> { + static LuaValue from(LuaEngine& engine, Variant<FirstType, RestTypes...> const& variant) { + return variant.call([&engine](auto const& a) { return luaFrom(engine, a); }); + } + + static LuaValue from(LuaEngine& engine, Variant<FirstType, RestTypes...>&& variant) { + return variant.call([&engine](auto& a) { return luaFrom(engine, move(a)); }); + } + + static Maybe<Variant<FirstType, RestTypes...>> to(LuaEngine& engine, LuaValue const& v) { + return checkTypeTo<FirstType, RestTypes...>(engine, v); + } + + template <typename CheckType1, typename CheckType2, typename... CheckTypeRest> + static Maybe<Variant<FirstType, RestTypes...>> checkTypeTo(LuaEngine& engine, LuaValue const& v) { + if (auto t1 = engine.luaMaybeTo<CheckType1>(v)) + return t1; + else + return checkTypeTo<CheckType2, CheckTypeRest...>(engine, v); + } + + template <typename Type> + static Maybe<Variant<FirstType, RestTypes...>> checkTypeTo(LuaEngine& engine, LuaValue const& v) { + return engine.luaMaybeTo<Type>(v); + } + + static Maybe<Variant<FirstType, RestTypes...>> to(LuaEngine& engine, LuaValue&& v) { + return checkTypeTo<FirstType, RestTypes...>(engine, move(v)); + } + + template <typename CheckType1, typename CheckType2, typename... CheckTypeRest> + static Maybe<Variant<FirstType, RestTypes...>> checkTypeTo(LuaEngine& engine, LuaValue&& v) { + if (auto t1 = engine.luaMaybeTo<CheckType1>(v)) + return t1; + else + return checkTypeTo<CheckType2, CheckTypeRest...>(engine, move(v)); + } + + template <typename Type> + static Maybe<Variant<FirstType, RestTypes...>> checkTypeTo(LuaEngine& engine, LuaValue&& v) { + return engine.luaMaybeTo<Type>(move(v)); + } +}; + +// Similarly to Variant converter, tries to convert from all types in order. +// An empty MVariant is converted to nil and vice versa. +template <typename... Types> +struct LuaConverter<MVariant<Types...>> { + static LuaValue from(LuaEngine& engine, MVariant<Types...> const& variant) { + LuaValue value; + variant.call([&value, &engine](auto const& a) { + value = luaFrom(engine, a); + }); + return value; + } + + static LuaValue from(LuaEngine& engine, MVariant<Types...>&& variant) { + LuaValue value; + variant.call([&value, &engine](auto& a) { + value = luaFrom(engine, move(a)); + }); + return value; + } + + static Maybe<MVariant<Types...>> to(LuaEngine& engine, LuaValue const& v) { + if (v == LuaNil) + return MVariant<Types...>(); + return checkTypeTo<Types...>(engine, v); + } + + template <typename CheckType1, typename CheckType2, typename... CheckTypeRest> + static Maybe<MVariant<Types...>> checkTypeTo(LuaEngine& engine, LuaValue const& v) { + if (auto t1 = engine.luaMaybeTo<CheckType1>(v)) + return t1; + else + return checkTypeTo<CheckType2, CheckTypeRest...>(engine, v); + } + + template <typename CheckType> + static Maybe<MVariant<Types...>> checkTypeTo(LuaEngine& engine, LuaValue const& v) { + return engine.luaMaybeTo<CheckType>(v); + } + + static Maybe<MVariant<Types...>> to(LuaEngine& engine, LuaValue&& v) { + if (v == LuaNil) + return MVariant<Types...>(); + return checkTypeTo<Types...>(engine, move(v)); + } + + template <typename CheckType1, typename CheckType2, typename... CheckTypeRest> + static Maybe<MVariant<Types...>> checkTypeTo(LuaEngine& engine, LuaValue&& v) { + if (auto t1 = engine.luaMaybeTo<CheckType1>(v)) + return t1; + else + return checkTypeTo<CheckType2, CheckTypeRest...>(engine, move(v)); + } + + template <typename CheckType> + static Maybe<MVariant<Types...>> checkTypeTo(LuaEngine& engine, LuaValue&& v) { + return engine.luaMaybeTo<CheckType>(move(v)); + } + +}; + +template <> +struct LuaConverter<Color> { + static LuaValue from(LuaEngine& engine, Color const& c); + static Maybe<Color> to(LuaEngine& engine, LuaValue const& v); +}; + +} + +#endif diff --git a/source/core/StarMap.hpp b/source/core/StarMap.hpp new file mode 100644 index 0000000..45dd219 --- /dev/null +++ b/source/core/StarMap.hpp @@ -0,0 +1,318 @@ +#ifndef STAR_MAP_HPP +#define STAR_MAP_HPP + +#include <map> +#include <unordered_map> + +#include "StarFlatHashMap.hpp" +#include "StarList.hpp" + +namespace Star { + +STAR_EXCEPTION(MapException, StarException); + +template <typename BaseMap> +class MapMixin : public BaseMap { +public: + typedef BaseMap Base; + + typedef typename Base::iterator iterator; + typedef typename Base::const_iterator const_iterator; + + typedef typename Base::key_type key_type; + typedef typename Base::mapped_type mapped_type; + typedef typename Base::value_type value_type; + + typedef typename std::decay<mapped_type>::type* mapped_ptr; + typedef typename std::decay<mapped_type>::type const* mapped_const_ptr; + + template <typename MapType> + static MapMixin from(MapType const& m); + + using Base::Base; + + List<key_type> keys() const; + List<mapped_type> values() const; + List<pair<key_type, mapped_type>> pairs() const; + + bool contains(key_type const& k) const; + + // Removes the item with key k and returns true if contains(k) is true, + // false otherwise. + bool remove(key_type const& k); + + // Removes *all* items that have a value matching the given one. Returns + // true if any elements were removed. + bool removeValues(mapped_type const& v); + + // Throws exception if key not found + mapped_type take(key_type const& k); + + Maybe<mapped_type> maybeTake(key_type const& k); + + // Throws exception if key not found + mapped_type& get(key_type const& k); + mapped_type const& get(key_type const& k) const; + + // Return d if key not found + mapped_type value(key_type const& k, mapped_type d = mapped_type()) const; + + Maybe<mapped_type> maybe(key_type const& k) const; + + mapped_const_ptr ptr(key_type const& k) const; + mapped_ptr ptr(key_type const& k); + + // Finds first value matching the given value and returns its key. + key_type keyOf(mapped_type const& v) const; + + // Finds all of the values matching the given value and returns their keys. + List<key_type> keysOf(mapped_type const& v) const; + + bool hasValue(mapped_type const& v) const; + + using Base::insert; + + // Same as insert(value_type), returns the iterator to either the newly + // inserted value or the existing value, and then a bool that is true if the + // new element was inserted. + pair<iterator, bool> insert(key_type k, mapped_type v); + + // Add a key / value pair, throw if the key already exists + mapped_type& add(key_type k, mapped_type v); + + // Set a key to a value, always override if it already exists + mapped_type& set(key_type k, mapped_type v); + + // Appends all values of given map into this map. If overwite is false, then + // skips values that already exist in this map. Returns false if any keys + // previously existed. + template <typename MapType> + bool merge(MapType const& m, bool overwrite = false); + + bool operator==(MapMixin const& m) const; +}; + +template <typename BaseMap> +std::ostream& operator<<(std::ostream& os, MapMixin<BaseMap> const& m); + +template <typename Key, typename Value, typename Compare = std::less<Key>, typename Allocator = std::allocator<pair<Key const, Value>>> +using Map = MapMixin<std::map<Key, Value, Compare, Allocator>>; + +template <typename Key, typename Value, typename Hash = hash<Key>, typename Equals = std::equal_to<Key>, typename Allocator = std::allocator<pair<Key const, Value>>> +using HashMap = MapMixin<FlatHashMap<Key, Value, Hash, Equals, Allocator>>; + +template <typename Key, typename Value, typename Hash = hash<Key>, typename Equals = std::equal_to<Key>, typename Allocator = std::allocator<pair<Key const, Value>>> +using StableHashMap = MapMixin<std::unordered_map<Key, Value, Hash, Equals, Allocator>>; + +template <typename BaseMap> +template <typename MapType> +auto MapMixin<BaseMap>::from(MapType const& m) -> MapMixin { + return MapMixin(m.begin(), m.end()); +} + +template <typename BaseMap> +auto MapMixin<BaseMap>::keys() const -> List<key_type> { + List<key_type> klist; + for (const_iterator i = Base::begin(); i != Base::end(); ++i) + klist.push_back(i->first); + return klist; +} + +template <typename BaseMap> +auto MapMixin<BaseMap>::values() const -> List<mapped_type> { + List<mapped_type> vlist; + for (const_iterator i = Base::begin(); i != Base::end(); ++i) + vlist.push_back(i->second); + return vlist; +} + +template <typename BaseMap> +auto MapMixin<BaseMap>::pairs() const -> List<pair<key_type, mapped_type>> { + List<pair<key_type, mapped_type>> plist; + for (const_iterator i = Base::begin(); i != Base::end(); ++i) + plist.push_back(*i); + return plist; +} + +template <typename BaseMap> +bool MapMixin<BaseMap>::contains(key_type const& k) const { + return Base::find(k) != Base::end(); +} + +template <typename BaseMap> +bool MapMixin<BaseMap>::remove(key_type const& k) { + return Base::erase(k) != 0; +} + +template <typename BaseMap> +bool MapMixin<BaseMap>::removeValues(mapped_type const& v) { + bool removed = false; + const_iterator i = Base::begin(); + while (i != Base::end()) { + if (i->second == v) { + Base::erase(i++); + removed = true; + } else { + ++i; + } + } + return removed; +} + +template <typename BaseMap> +auto MapMixin<BaseMap>::take(key_type const& k) -> mapped_type { + if (auto v = maybeTake(k)) + return v.take(); + throw MapException(strf("Key '%s' not found in Map::take()", outputAny(k))); +} + +template <typename BaseMap> +auto MapMixin<BaseMap>::maybeTake(key_type const& k) -> Maybe<mapped_type> { + const_iterator i = Base::find(k); + if (i != Base::end()) { + mapped_type v = std::move(i->second); + Base::erase(i); + return move(v); + } + + return {}; +} + +template <typename BaseMap> +auto MapMixin<BaseMap>::get(key_type const& k) -> mapped_type& { + iterator i = Base::find(k); + if (i == Base::end()) + throw MapException(strf("Key '%s' not found in Map::get()", outputAny(k))); + return i->second; +} + +template <typename BaseMap> +auto MapMixin<BaseMap>::get(key_type const& k) const -> mapped_type const& { + const_iterator i = Base::find(k); + if (i == Base::end()) + throw MapException(strf("Key '%s' not found in Map::get()", outputAny(k))); + return i->second; +} + +template <typename BaseMap> +auto MapMixin<BaseMap>::value(key_type const& k, mapped_type d) const -> mapped_type { + const_iterator i = Base::find(k); + if (i == Base::end()) + return std::move(d); + else + return i->second; +} + +template <typename BaseMap> +auto MapMixin<BaseMap>::maybe(key_type const& k) const -> Maybe<mapped_type> { + auto i = Base::find(k); + if (i == Base::end()) + return {}; + else + return i->second; +} + +template <typename BaseMap> +auto MapMixin<BaseMap>::ptr(key_type const& k) const -> mapped_const_ptr { + auto i = Base::find(k); + if (i == Base::end()) + return nullptr; + else + return &i->second; +} + +template <typename BaseMap> +auto MapMixin<BaseMap>::ptr(key_type const& k) -> mapped_ptr { + auto i = Base::find(k); + if (i == Base::end()) + return nullptr; + else + return &i->second; +} + +template <typename BaseMap> +auto MapMixin<BaseMap>::keyOf(mapped_type const& v) const -> key_type { + for (const_iterator i = Base::begin(); i != Base::end(); ++i) { + if (i->second == v) + return i->first; + } + throw MapException(strf("Value '%s' not found in Map::keyOf()", outputAny(v))); +} + +template <typename BaseMap> +auto MapMixin<BaseMap>::keysOf(mapped_type const& v) const -> List<key_type> { + List<key_type> keys; + for (const_iterator i = Base::begin(); i != Base::end(); ++i) { + if (i->second == v) + keys.append(i->first); + } + return keys; +} + +template <typename BaseMap> +auto MapMixin<BaseMap>::hasValue(mapped_type const& v) const -> bool { + for (const_iterator i = Base::begin(); i != Base::end(); ++i) { + if (i->second == v) + return true; + } + return false; +} + +template <typename BaseMap> +auto MapMixin<BaseMap>::insert(key_type k, mapped_type v) -> pair<iterator, bool> { + return Base::insert(value_type(move(k), move(v))); +} + +template <typename BaseMap> +auto MapMixin<BaseMap>::add(key_type k, mapped_type v) -> mapped_type& { + auto pair = Base::insert(value_type(move(k), move(v))); + if (!pair.second) + throw MapException(strf("Entry with key '%s' already present.", outputAny(k))); + else + return pair.first->second; +} + +template <typename BaseMap> +auto MapMixin<BaseMap>::set(key_type k, mapped_type v) -> mapped_type& { + auto i = Base::find(k); + if (i != Base::end()) { + i->second = move(v); + return i->second; + } else { + return Base::insert(value_type(move(k), move(v))).first->second; + } +} + +template <typename BaseMap> +template <typename OtherMapType> +bool MapMixin<BaseMap>::merge(OtherMapType const& m, bool overwrite) { + return mapMerge(*this, m, overwrite); +} + +template <typename BaseMap> +bool MapMixin<BaseMap>::operator==(MapMixin const& m) const { + return this == &m || mapsEqual(*this, m); +} + +template <typename MapType> +void printMap(std::ostream& os, MapType const& m) { + os << "{ "; + for (auto i = m.begin(); i != m.end(); ++i) { + if (m.begin() == i) + os << "\""; + else + os << ", \""; + os << i->first << "\" : \"" << i->second << "\""; + } + os << " }"; +} + +template <typename BaseMap> +std::ostream& operator<<(std::ostream& os, MapMixin<BaseMap> const& m) { + printMap(os, m); + return os; +} + +} + +#endif diff --git a/source/core/StarMathCommon.hpp b/source/core/StarMathCommon.hpp new file mode 100644 index 0000000..27d2976 --- /dev/null +++ b/source/core/StarMathCommon.hpp @@ -0,0 +1,328 @@ +#ifndef STAR_MATH_COMMON_HPP +#define STAR_MATH_COMMON_HPP + +#include <type_traits> +#include <limits> + +#include "StarMaybe.hpp" + +namespace Star { + +STAR_EXCEPTION(MathException, StarException); + +namespace Constants { + double const pi = 3.14159265358979323846; + double const rad2deg = 57.2957795130823208768; + double const deg2rad = 1 / rad2deg; + double const sqrt2 = 1.41421356237309504880; + double const log2e = 1.44269504088896340736; +} + +// Really common std namespace includes, and replacements for std libraries +// that don't provide them + +using std::abs; +using std::fabs; +using std::sqrt; +using std::floor; +using std::ceil; +using std::round; +using std::fmod; +using std::sin; +using std::cos; +using std::tan; +using std::pow; +using std::atan2; +using std::log; +using std::log10; +using std::copysign; + +inline float log2(float f) { + return log(f) * (float)Constants::log2e; +} + +inline double log2(double d) { + return log(d) * Constants::log2e; +} + +// Count the number of '1' bits in the given unsigned integer +template <typename Int> +typename std::enable_if<std::is_integral<Int>::value && std::is_unsigned<Int>::value, unsigned>::type countSetBits(Int value) { + unsigned count = 0; + while (value != 0) { + value &= (value - 1); + ++count; + } + return count; +} + +template <typename T, typename T2> +typename std::enable_if<!std::numeric_limits<T>::is_integer && !std::numeric_limits<T2>::is_integer && sizeof(T) >= sizeof(T2), bool>::type +nearEqual(T x, T2 y, unsigned ulp) { + auto epsilon = std::numeric_limits<T>::epsilon(); + return abs(x - y) <= epsilon * max(abs(x), (T)abs(y)) * ulp; +} + +template <typename T, typename T2> +typename std::enable_if<!std::numeric_limits<T>::is_integer && !std::numeric_limits<T2>::is_integer && sizeof(T) < sizeof(T2), bool>::type +nearEqual(T x, T2 y, unsigned ulp) { + return nearEqual(y, x, ulp); +} + +template <typename T, typename T2> +typename std::enable_if<std::numeric_limits<T>::is_integer && !std::numeric_limits<T2>::is_integer, bool>::type +nearEqual(T x, T2 y, unsigned ulp) { + return nearEqual((double)x, y, ulp); +} + +template <typename T, typename T2> +typename std::enable_if<!std::numeric_limits<T>::is_integer && std::numeric_limits<T2>::is_integer, bool>::type +nearEqual(T x, T2 y, unsigned ulp) { + return nearEqual(x, (double)y, ulp); +} + +template <typename T, typename T2> +typename std::enable_if<std::numeric_limits<T>::is_integer && std::numeric_limits<T2>::is_integer, bool>::type +nearEqual(T x, T2 y, unsigned) { + return x == y; +} + +template <typename T, typename T2> +bool nearEqual(T x, T2 y) { + return nearEqual(x, y, 1); +} + +template <typename T> +typename std::enable_if<!std::numeric_limits<T>::is_integer, bool>::type nearZero(T x, unsigned ulp = 2) { + return abs(x) <= std::numeric_limits<T>::min() * ulp; +} + +template <typename T> +typename std::enable_if<std::numeric_limits<T>::is_integer, bool>::type nearZero(T x) { + return x == 0; +} + +template <typename T> +constexpr T lowest() { + return std::numeric_limits<T>::lowest(); +} + +template <typename T> +constexpr T highest() { + return std::numeric_limits<T>::max(); +} + +template <typename T> +constexpr T square(T const& x) { + return x * x; +} + +template <typename T> +constexpr T cube(T const& x) { + return x * x * x; +} + +template <typename Float> +int ipart(Float f) { + return (int)floor(f); +} + +template <typename Float> +Float fpart(Float f) { + return f - ipart(f); +} + +template <typename Float> +Float rfpart(Float f) { + return 1.0 - fpart(f); +} + +template <typename T, typename T2> +T clampMagnitude(T const& v, T2 const& mag) { + if (v > mag) + return mag; + else if (v < -mag) + return -mag; + else + return v; +} + +template <typename T> +T clamp(T const val, T const min, T const max) { + return std::min(std::max(val, min), max); +} + +template <typename T> +T clampDynamic(T const val, T const a, T const b) { + return std::min(std::max(val, std::min(a, b)), std::max(a, b)); +} + +template <typename IntType, typename PowType> +IntType intPow(IntType i, PowType p) { + starAssert(p >= 0); + + if (p == 0) + return 1; + if (p == 1) + return i; + + IntType tmp = intPow(i, p / 2); + if ((p % 2) == 0) + return tmp * tmp; + else + return i * tmp * tmp; +} + +template <typename Int> +bool isPowerOf2(Int x) { + if (x < 1) + return false; + return (x & (x - 1)) == 0; +} + +inline uint64_t ceilPowerOf2(uint64_t v) { + v--; + v |= v >> 1; + v |= v >> 2; + v |= v >> 4; + v |= v >> 8; + v |= v >> 16; + v |= v >> 32; + v++; + return v; +} + +template <typename Float> +Float sigmoid(Float x) { + return 1 / (1 + std::exp(-x)); +} + +// returns a % m such that the answer is always positive. +// For example, -1 mod 10 is 9. +template <typename IntType> +IntType pmod(IntType a, IntType m) { + IntType r = a % m; + return r < 0 ? r + m : r; +} + +// Same as pmod but for float like values. +template <typename Float> +Float pfmod(Float a, Float m) { + if (m == 0) + return a; + + return a - m * floor(a / m); +} + +// Finds the *smallest* distance (in absolute value terms) from b to a (a - b) +// in a non-euclidean wrapping number line. Suppose size is 100, wrapDiff(10, +// 109) would return 1, because 509 is congruent to the point 9. On the other +// hand, wrapDiff(10, 111) would return -1, because 111 is congruent to the +// point 11. +template <typename Type> +Type wrapDiff(Type a, Type b, Type size) { + a = pmod(a, size); + b = pmod(b, size); + + Type diff = a - b; + if (diff > size / 2) + diff -= size; + else if (diff < -size / 2) + diff += size; + + return diff; +} + +// Sampe as wrapDiff but for float like values +template <typename Type> +Type wrapDiffF(Type a, Type b, Type size) { + a = pfmod(a, size); + b = pfmod(b, size); + + Type diff = a - b; + if (diff > size / 2) + diff -= size; + else if (diff < -size / 2) + diff += size; + + return diff; +} + +// like std::pow, except ignores sign, and the return value will match the sign +// of the value passed in. ppow(-2, 2) == -4 +template <typename Float> +Float ppow(Float val, Float pow) { + return copysign(std::pow(std::fabs(val), pow), val); +} + +// Returns angle wrapped around to the range [-pi, pi). +template <typename Float> +Float constrainAngle(Float angle) { + angle = fmod((Float)(angle + Constants::pi), (Float)(Constants::pi * 2)); + if (angle < 0) + angle += Constants::pi * 2; + return angle - Constants::pi; +} + +// Returns the closest angle movement to go from the given angle to the target +// angle, in radians. +template <typename Float> +Float angleDiff(Float angle, Float targetAngle) { + double diff = fmod((Float)(targetAngle - angle + Constants::pi), (Float)(Constants::pi * 2)); + if (diff < 0) + diff += Constants::pi * 2; + return diff - Constants::pi; +} + +// Approach the given goal value from the current value, at a maximum rate of +// change. Rate should always be a positive value. (T must be signed). +template <typename T> +T approach(T goal, T current, T rate) { + if (goal < current) { + return max(current - rate, goal); + } else if (goal > current) { + return min(current + rate, goal); + } else { + return current; + } +} + +// Same as approach, specialied for angles, and always approaches from the +// closest absolute direction. +template <typename T> +T approachAngle(T goal, T current, T rate) { + return constrainAngle(current + clampMagnitude<T>(angleDiff(current, goal), rate)); +} + +// Used in color conversion from floating point to uint8_t +inline uint8_t floatToByte(float val, bool doClamp = false) { + if (doClamp) + val = clamp(val, 0.0f, 1.0f); + return (uint8_t)(val * 255.0f); +} + +// Used in color conversion from uint8_t to normalized float. +inline float byteToFloat(uint8_t val) { + return val / 255.0f; +} + +// Turn a randomized floating point value from [0.0, 1.0] to [-1.0, 1.0] +template <typename Float> +Float randn(Float val) { + return val * 2 - 1; +} + +// Increments a value between min and max inclusive, cycling around to min when +// it would be incremented beyond max. If the value is outside of the range, +// the next increment will start at min. +template <typename Integer> +Integer cycleIncrement(Integer val, Integer min, Integer max) { + if (val < min || val >= max) + return min; + else + return val + 1; +} + +} + +#endif diff --git a/source/core/StarMatrix3.hpp b/source/core/StarMatrix3.hpp new file mode 100644 index 0000000..6c688c8 --- /dev/null +++ b/source/core/StarMatrix3.hpp @@ -0,0 +1,456 @@ +#ifndef STAR_MATRIX3_HPP +#define STAR_MATRIX3_HPP + +#include "StarVector.hpp" + +namespace Star { + +template <typename T> +class Matrix3 { +public: + typedef Vector<T, 3> Vec3; + typedef Vector<T, 2> Vec2; + typedef Array<Vec3, 3> Rows; + + // Only enable pointer access if we know that our internal rows are not + // padded + template <typename RT = void> + using EnableIfContiguousStorage = + typename std::enable_if<sizeof(Vec3) == 3 * sizeof(T) && sizeof(Rows) == 3 * sizeof(Vec3), RT>::type; + + static Matrix3 identity(); + + // Construct an affine 2d transform + static Matrix3 rotation(T angle, Vec2 const& point = Vec2()); + static Matrix3 translation(Vec2 const& point); + static Matrix3 scaling(T scale, Vec2 const& point = Vec2()); + static Matrix3 scaling(Vec2 const& scale, Vec2 const& point = Vec2()); + + Matrix3(); + + Matrix3(T r1c1, T r1c2, T r1c3, T r2c1, T r2c2, T r2c3, T r3c1, T r3c2, T r3c3); + + Matrix3(Vec3 const& r1, Vec3 const& r2, Vec3 const& r3); + + Matrix3(T const* ptr); + template <typename T2> + Matrix3(Matrix3<T2> const& m); + + template <typename T2> + Matrix3& operator=(Matrix3<T2> const& m); + + // Row-major indexing + Vec3& operator[](size_t const i); + Vec3 const& operator[](size_t const i) const; + + // Gives pointer to row major storage + EnableIfContiguousStorage<T*> ptr(); + EnableIfContiguousStorage<T const*> ptr() const; + + // Copy to an existing array + void copy(T* loc) const; + + Vec3 row(size_t i) const; + template <typename T2> + void setRow(size_t i, Vector<T2, 3> const& v); + + Vec3 col(size_t i); + template <typename T2> + void setCol(size_t i, Vector<T2, 3> const& v); + + T determinant() const; + Vec3 trace() const; + Matrix3 inverse() const; + bool isOrthogonal(T tolerance) const; + + void transpose(); + void orthogonalize(); + void invert(); + + // Apply the given 2d affine transformation to this matrix in global + // coordinates + void rotate(T angle, Vec2 const& point = Vec2()); + void translate(Vec2 const& point); + void scale(Vec2 const& scale, Vec2 const& point = Vec2()); + void scale(T scale, Vec2 const& point = Vec2()); + + // Do an affine transformation of the given 2d vector. + template <typename T2> + Vector<T2, 2> transformVec2(Vector<T2, 2> const& v2) const; + + // The resulting angle of a transformation on any ray with this angle. + float transformAngle(float angle) const; + + bool operator==(Matrix3 const& m2) const; + bool operator!=(Matrix3 const& m2) const; + + Matrix3& operator*=(T const& s); + Matrix3& operator/=(T const& s); + Matrix3 operator*(T const& s) const; + Matrix3 operator/(T const& s) const; + Matrix3 operator-() const; + + template <typename T2> + Matrix3& operator+=(Matrix3<T2> const& m2); + + template <typename T2> + Matrix3& operator-=(Matrix3<T2> const& m2); + + template <typename T2> + Matrix3& operator*=(Matrix3<T2> const& m2); + + template <typename T2> + Matrix3 operator+(Matrix3<T2> const& m2) const; + + template <typename T2> + Matrix3 operator-(Matrix3<T2> const& m2) const; + + template <typename T2> + Matrix3 operator*(Matrix3<T2> const& m2) const; + + template <typename T2> + Vec3 operator*(Vector<T2, 3> const& v) const; + +private: + Rows m_rows; +}; + +typedef Matrix3<float> Mat3F; +typedef Matrix3<double> Mat3D; + +template <typename T> +Matrix3<T> Matrix3<T>::identity() { + return Matrix3(1, 0, 0, 0, 1, 0, 0, 0, 1); +} + +template <typename T> +Matrix3<T> Matrix3<T>::rotation(T angle, Vec2 const& point) { + T s = sin(angle); + T c = cos(angle); + return Matrix3(c, -s, point[0] - c * point[0] + s * point[1], s, c, point[1] - s * point[0] - c * point[1], 0, 0, 1); +} + +template <typename T> +Matrix3<T> Matrix3<T>::translation(Vec2 const& point) { + return Matrix3(1, 0, point[0], 0, 1, point[1], 0, 0, 1); +} + +template <typename T> +Matrix3<T> Matrix3<T>::scaling(T scale, Vec2 const& point) { + return scaling(Vec2::filled(scale), point); +} + +template <typename T> +Matrix3<T> Matrix3<T>::scaling(Vec2 const& scale, Vec2 const& point) { + return Matrix3(scale[0], 0, point[0] - point[0] * scale[0], 0, scale[1], point[1] - point[1] * scale[1], 0, 0, 1); +} + +template <typename T> +Matrix3<T>::Matrix3() {} + +template <typename T> +Matrix3<T>::Matrix3(T r1c1, T r1c2, T r1c3, T r2c1, T r2c2, T r2c3, T r3c1, T r3c2, T r3c3) + : m_rows(Vec3(r1c1, r1c2, r1c3), Vec3(r2c1, r2c2, r2c3), Vec3(r3c1, r3c2, r3c3)) {} + +template <typename T> +Matrix3<T>::Matrix3(const Vec3& r1, const Vec3& r2, const Vec3& r3) + : m_rows{r1, r2, r3} {} + +template <typename T> +Matrix3<T>::Matrix3(T const* ptr) + : m_rows{Vec3(ptr), Vec3(ptr + 3), Vec3(ptr + 6)} {} + +template <typename T> +template <typename T2> +Matrix3<T>::Matrix3(const Matrix3<T2>& m) { + *this = m; +} + +template <typename T> +template <typename T2> +Matrix3<T>& Matrix3<T>::operator=(const Matrix3<T2>& m) { + m_rows = m.m_rows; + return *this; +} + +template <typename T> +auto Matrix3<T>::operator[](const size_t i) -> Vec3 & { + return m_rows[i]; +} + +template <typename T> +auto Matrix3<T>::operator[](const size_t i) const -> Vec3 const & { + return m_rows[i]; +} + +template <typename T> +auto Matrix3<T>::ptr() -> EnableIfContiguousStorage<T*> { + return m_rows[0].ptr(); +} + +template <typename T> +auto Matrix3<T>::ptr() const -> EnableIfContiguousStorage<T const*> { + return m_rows[0].ptr(); +} + +template <typename T> +void Matrix3<T>::copy(T* loc) const { + m_rows[0].copyFrom(loc); + m_rows[1].copyFrom(loc + 3); + m_rows[2].copyFrom(loc + 6); +} + +template <typename T> +auto Matrix3<T>::row(size_t i) const -> Vec3 { + return operator[](i); +} + +template <typename T> +template <typename T2> +void Matrix3<T>::setRow(size_t i, const Vector<T2, 3>& v) { + operator[](i) = Vec3(v); +} + +template <typename T> +auto Matrix3<T>::col(size_t i) -> Vec3 { + return Vec3(m_rows[0][i], m_rows[1][i], m_rows[2][i]); +} + +template <typename T> +template <typename T2> +void Matrix3<T>::setCol(size_t i, const Vector<T2, 3>& v) { + m_rows[0][i] = T(v[0]); + m_rows[1][i] = T(v[1]); + m_rows[2][i] = T(v[2]); +} + +template <typename T> +T Matrix3<T>::determinant() const { + return m_rows[0][0] * m_rows[1][1] * m_rows[2][2] - m_rows[0][0] * m_rows[2][1] * m_rows[1][2] + + m_rows[1][0] * m_rows[2][1] * m_rows[0][2] - m_rows[1][0] * m_rows[0][1] * m_rows[2][2] + + m_rows[2][0] * m_rows[0][1] * m_rows[1][2] - m_rows[2][0] * m_rows[1][1] * m_rows[0][2]; +} + +template <typename T> +void Matrix3<T>::transpose() { + std::swap(m_rows[1][0], m_rows[0][1]); + std::swap(m_rows[2][0], m_rows[0][2]); + std::swap(m_rows[2][1], m_rows[1][2]); +} + +template <typename T> +void Matrix3<T>::invert() { + T d = determinant(); + + m_rows[0][0] = (m_rows[1][1] * m_rows[2][2] - m_rows[1][2] * m_rows[2][1]) / d; + m_rows[0][1] = -(m_rows[0][1] * m_rows[2][2] - m_rows[0][2] * m_rows[2][1]) / d; + m_rows[0][2] = (m_rows[0][1] * m_rows[1][2] - m_rows[0][2] * m_rows[1][1]) / d; + m_rows[1][0] = -(m_rows[1][0] * m_rows[2][2] - m_rows[1][2] * m_rows[2][0]) / d; + m_rows[1][1] = (m_rows[0][0] * m_rows[2][2] - m_rows[0][2] * m_rows[2][0]) / d; + m_rows[1][2] = -(m_rows[0][0] * m_rows[1][2] - m_rows[0][2] * m_rows[1][0]) / d; + m_rows[2][0] = (m_rows[1][0] * m_rows[2][1] - m_rows[1][1] * m_rows[2][0]) / d; + m_rows[2][1] = -(m_rows[0][0] * m_rows[2][1] - m_rows[0][1] * m_rows[2][0]) / d; + m_rows[2][2] = (m_rows[0][0] * m_rows[1][1] - m_rows[0][1] * m_rows[1][0]) / d; +} + +template <typename T> +Matrix3<T> Matrix3<T>::inverse() const { + auto m = *this; + m.invert(); + return m; +} + +template <typename T> +void Matrix3<T>::orthogonalize() { + m_rows[0].normalize(); + T dot = m_rows[0] * m_rows[1]; + m_rows[1][0] -= m_rows[0][0] * dot; + m_rows[1][1] -= m_rows[0][1] * dot; + m_rows[1][2] -= m_rows[0][2] * dot; + m_rows[1].normalize(); + + dot = m_rows[1] * m_rows[2]; + m_rows[2][0] -= m_rows[1][0] * dot; + m_rows[2][1] -= m_rows[1][1] * dot; + m_rows[2][2] -= m_rows[1][2] * dot; + m_rows[2].normalize(); +} + +template <typename T> +bool Matrix3<T>::isOrthogonal(T tolerance) const { + T det = determinant(); + return std::fabs(det - 1) < tolerance || std::fabs(det + 1) < tolerance; +} + +template <typename T> +void Matrix3<T>::rotate(T angle, Vec2 const& point) { + *this = rotation(angle, point) * *this; +} + +template <typename T> +void Matrix3<T>::translate(Vec2 const& point) { + *this = translation(point) * *this; +} + +template <typename T> +void Matrix3<T>::scale(Vec2 const& scale, Vec2 const& point) { + *this = scaling(scale, point) * *this; +} + +template <typename T> +void Matrix3<T>::scale(T scale, Vec2 const& point) { + *this = scaling(scale, point) * *this; +} + +template <typename T> +template <typename T2> +Vector<T2, 2> Matrix3<T>::transformVec2(Vector<T2, 2> const& point) const { + Vector<T2, 3> res = (*this) * Vector<T2, 3>(point, 1); + return res.vec2(); +} + +template <typename T> +float Matrix3<T>::transformAngle(float angle) const { + Vec2 a = Vec2::withAngle(angle, 1.0f); + Matrix3 m = *this; + m[0][2] = 0; + m[1][2] = 0; + return m.transformVec2(a).angle(); +} + +template <typename T> +bool Matrix3<T>::operator==(Matrix3 const& m2) const { + return tie(m_rows[0], m_rows[1], m_rows[2]) == tie(m2.m_rows[0], m2.m_rows[1], m2.m_rows[2]); +} + +template <typename T> +bool Matrix3<T>::operator!=(Matrix3 const& m2) const { + return tie(m_rows[0], m_rows[1], m_rows[2]) != tie(m2.m_rows[0], m2.m_rows[1], m2.m_rows[2]); +} + +template <typename T> +Matrix3<T>& Matrix3<T>::operator*=(const T& s) { + m_rows[0] *= s; + m_rows[1] *= s; + m_rows[2] *= s; + return *this; +} + +template <typename T> +Matrix3<T>& Matrix3<T>::operator/=(const T& s) { + m_rows[0] /= s; + m_rows[1] /= s; + m_rows[2] /= s; + return *this; +} + +template <typename T> +auto Matrix3<T>::trace() const -> Vec3 { + return Vec3(m_rows[0][0], m_rows[1][1], m_rows[2][2]); +} + +template <typename T> +Matrix3<T> Matrix3<T>::operator-() const { + return Matrix3(-m_rows[0], -m_rows[1], -m_rows[2]); +} + +template <typename T> +template <typename T2> +Matrix3<T>& Matrix3<T>::operator+=(const Matrix3<T2>& m) { + m_rows[0] += m[0]; + m_rows[1] += m[1]; + m_rows[2] += m[2]; + return *this; +} + +template <typename T> +template <typename T2> +Matrix3<T>& Matrix3<T>::operator-=(const Matrix3<T2>& m) { + m_rows[0] -= m[0]; + m_rows[1] -= m[1]; + m_rows[2] -= m[2]; + return *this; +} + +template <typename T> +template <typename T2> +Matrix3<T>& Matrix3<T>::operator*=(Matrix3<T2> const& m2) { + *this = *this * m2; + return *this; +} + +template <typename T> +template <typename T2> +Matrix3<T> Matrix3<T>::operator+(const Matrix3<T2>& m2) const { + return Matrix3<T>(m_rows[0] + m2[0], m_rows[1] + m2[1], m_rows[2] + m2[2]); +} + +template <typename T> +template <typename T2> +Matrix3<T> Matrix3<T>::operator-(const Matrix3<T2>& m2) const { + return Matrix3<T>(m_rows[0] - m2[0], m_rows[1] - m2[1], m_rows[2] - m2[2]); +} + +template <typename T> +template <typename T2> +Matrix3<T> Matrix3<T>::operator*(const Matrix3<T2>& m2) const { + return Matrix3<T>(m_rows[0][0] * m2[0][0] + m_rows[0][1] * m2[1][0] + m_rows[0][2] * m2[2][0], + m_rows[0][0] * m2[0][1] + m_rows[0][1] * m2[1][1] + m_rows[0][2] * m2[2][1], + m_rows[0][0] * m2[0][2] + m_rows[0][1] * m2[1][2] + m_rows[0][2] * m2[2][2], + m_rows[1][0] * m2[0][0] + m_rows[1][1] * m2[1][0] + m_rows[1][2] * m2[2][0], + m_rows[1][0] * m2[0][1] + m_rows[1][1] * m2[1][1] + m_rows[1][2] * m2[2][1], + m_rows[1][0] * m2[0][2] + m_rows[1][1] * m2[1][2] + m_rows[1][2] * m2[2][2], + m_rows[2][0] * m2[0][0] + m_rows[2][1] * m2[1][0] + m_rows[2][2] * m2[2][0], + m_rows[2][0] * m2[0][1] + m_rows[2][1] * m2[1][1] + m_rows[2][2] * m2[2][1], + m_rows[2][0] * m2[0][2] + m_rows[2][1] * m2[1][2] + m_rows[2][2] * m2[2][2]); +} + +template <typename T> +template <typename T2> +auto Matrix3<T>::operator*(const Vector<T2, 3>& u) const -> Vec3 { + return Vec3(m_rows[0][0] * u[0] + m_rows[0][1] * u[1] + m_rows[0][2] * u[2], + m_rows[1][0] * u[0] + m_rows[1][1] * u[1] + m_rows[1][2] * u[2], + m_rows[2][0] * u[0] + m_rows[2][1] * u[1] + m_rows[2][2] * u[2]); +} + +template <typename T> +Matrix3<T> Matrix3<T>::operator/(const T& s) const { + return Matrix3<T>(m_rows[0] / s, m_rows[1] / s, m_rows[2] / s); +} + +template <typename T> +Matrix3<T> Matrix3<T>::operator*(const T& s) const { + return Matrix3<T>(m_rows[0] * s, m_rows[1] * s, m_rows[2] * s); +} + +template <typename T> +T determinant(const Matrix3<T>& m) { + return m.determinant(); +} + +template <typename T> +Matrix3<T> transpose(Matrix3<T> m) { + return m.transpose(); +} + +template <typename T> +Matrix3<T> ortho(Matrix3<T> mat) { + return mat.orthogonalize(); +} + +template <typename T> +Matrix3<T> operator*(T s, const Matrix3<T>& m) { + return m * s; +} + +template <typename T> +std::ostream& operator<<(std::ostream& os, Matrix3<T> m) { + os << m[0][0] << ' ' << m[0][1] << ' ' << m[0][2] << std::endl; + os << m[1][0] << ' ' << m[1][1] << ' ' << m[1][2] << std::endl; + os << m[2][0] << ' ' << m[2][1] << ' ' << m[2][2]; + return os; +} + +} + +#endif diff --git a/source/core/StarMaybe.hpp b/source/core/StarMaybe.hpp new file mode 100644 index 0000000..f0e7904 --- /dev/null +++ b/source/core/StarMaybe.hpp @@ -0,0 +1,400 @@ +#ifndef STAR_MAYBE_HPP +#define STAR_MAYBE_HPP + +#include "StarException.hpp" +#include "StarHash.hpp" + +namespace Star { + +STAR_EXCEPTION(InvalidMaybeAccessException, StarException); + +template <typename T> +class Maybe { +public: + typedef T* PointerType; + typedef T const* PointerConstType; + typedef T& RefType; + typedef T const& RefConstType; + + Maybe(); + + Maybe(T const& t); + Maybe(T&& t); + + Maybe(Maybe const& rhs); + Maybe(Maybe&& rhs); + template <typename T2> + Maybe(Maybe<T2> const& rhs); + + ~Maybe(); + + Maybe& operator=(Maybe const& rhs); + Maybe& operator=(Maybe&& rhs); + template <typename T2> + Maybe& operator=(Maybe<T2> const& rhs); + + bool isValid() const; + bool isNothing() const; + explicit operator bool() const; + + PointerConstType ptr() const; + PointerType ptr(); + + PointerConstType operator->() const; + PointerType operator->(); + + RefConstType operator*() const; + RefType operator*(); + + bool operator==(Maybe const& rhs) const; + bool operator!=(Maybe const& rhs) const; + bool operator<(Maybe const& rhs) const; + + RefConstType get() const; + RefType get(); + + // Get either the contents of this Maybe or the given default. + T value(T def = T()) const; + + // Get either this value, or if this value is none the given value. + Maybe orMaybe(Maybe const& other) const; + + // Takes the value out of this Maybe, leaving it Nothing. + T take(); + + // If this Maybe is set, assigns it to t and leaves this Maybe as Nothing. + bool put(T& t); + + void set(T const& t); + void set(T&& t); + + template <typename... Args> + void emplace(Args&&... t); + + void reset(); + + // Apply a function to the contained value if it is not Nothing. + template <typename Function> + void exec(Function&& function); + + // Functor map operator. If this maybe is not Nothing, then applies the + // given function to it and returns the result, otherwise returns Nothing (of + // the type the function would normally return). + template <typename Function> + auto apply(Function&& function) const -> Maybe<typename std::decay<decltype(function(std::declval<T>()))>::type>; + + // Monadic bind operator. Given function should return another Maybe. + template <typename Function> + auto sequence(Function function) const -> decltype(function(std::declval<T>())); + +private: + union { + T m_data; + }; + + bool m_initialized; +}; + +template <typename T> +std::ostream& operator<<(std::ostream& os, Maybe<T> const& v); + +template <typename T> +struct hash<Maybe<T>> { + size_t operator()(Maybe<T> const& m) const; + hash<T> hasher; +}; + +template <typename T> +Maybe<T>::Maybe() + : m_initialized(false) {} + +template <typename T> +Maybe<T>::Maybe(T const& t) + : Maybe() { + new (&m_data) T(t); + m_initialized = true; +} + +template <typename T> +Maybe<T>::Maybe(T&& t) + : Maybe() { + new (&m_data) T(forward<T>(t)); + m_initialized = true; +} + +template <typename T> +Maybe<T>::Maybe(Maybe const& rhs) + : Maybe() { + if (rhs.m_initialized) { + new (&m_data) T(rhs.m_data); + m_initialized = true; + } +} + +template <typename T> +Maybe<T>::Maybe(Maybe&& rhs) + : Maybe() { + if (rhs.m_initialized) { + new (&m_data) T(move(rhs.m_data)); + m_initialized = true; + rhs.reset(); + } +} + +template <typename T> +template <typename T2> +Maybe<T>::Maybe(Maybe<T2> const& rhs) + : Maybe() { + if (rhs) { + new (&m_data) T(*rhs); + m_initialized = true; + } +} + +template <typename T> +Maybe<T>::~Maybe() { + reset(); +} + +template <typename T> +Maybe<T>& Maybe<T>::operator=(Maybe const& rhs) { + if (&rhs == this) + return *this; + + if (rhs) + emplace(*rhs); + else + reset(); + + return *this; +} + +template <typename T> +template <typename T2> +Maybe<T>& Maybe<T>::operator=(Maybe<T2> const& rhs) { + if (rhs) + emplace(*rhs); + else + reset(); + + return *this; +} + +template <typename T> +Maybe<T>& Maybe<T>::operator=(Maybe&& rhs) { + if (&rhs == this) + return *this; + + if (rhs) + emplace(rhs.take()); + else + reset(); + + return *this; +} + +template <typename T> +bool Maybe<T>::isValid() const { + return m_initialized; +} + +template <typename T> +bool Maybe<T>::isNothing() const { + return !m_initialized; +} + +template <typename T> +Maybe<T>::operator bool() const { + return m_initialized; +} + +template <typename T> +auto Maybe<T>::ptr() const -> PointerConstType { + if (m_initialized) + return &m_data; + return nullptr; +} + +template <typename T> +auto Maybe<T>::ptr() -> PointerType { + if (m_initialized) + return &m_data; + return nullptr; +} + +template <typename T> +auto Maybe<T>::operator-> () const -> PointerConstType { + if (!m_initialized) + throw InvalidMaybeAccessException(); + + return &m_data; +} + +template <typename T> +auto Maybe<T>::operator->() -> PointerType { + if (!m_initialized) + throw InvalidMaybeAccessException(); + + return &m_data; +} + +template <typename T> +auto Maybe<T>::operator*() const -> RefConstType { + return get(); +} + +template <typename T> +auto Maybe<T>::operator*() -> RefType { + return get(); +} + +template <typename T> +bool Maybe<T>::operator==(Maybe const& rhs) const { + if (!m_initialized && !rhs.m_initialized) + return true; + if (m_initialized && rhs.m_initialized) + return get() == rhs.get(); + return false; +} + +template <typename T> +bool Maybe<T>::operator!=(Maybe const& rhs) const { + return !operator==(rhs); +} + +template <typename T> +bool Maybe<T>::operator<(Maybe const& rhs) const { + if (m_initialized && rhs.m_initialized) + return get() < rhs.get(); + if (!m_initialized && rhs.m_initialized) + return true; + return false; +} + +template <typename T> +auto Maybe<T>::get() const -> RefConstType { + if (!m_initialized) + throw InvalidMaybeAccessException(); + + return m_data; +} + +template <typename T> +auto Maybe<T>::get() -> RefType { + if (!m_initialized) + throw InvalidMaybeAccessException(); + + return m_data; +} + +template <typename T> +T Maybe<T>::value(T def) const { + if (m_initialized) + return m_data; + else + return def; +} + +template <typename T> +Maybe<T> Maybe<T>::orMaybe(Maybe const& other) const { + if (m_initialized) + return *this; + else + return other; +} + +template <typename T> +T Maybe<T>::take() { + if (!m_initialized) + throw InvalidMaybeAccessException(); + + T val(move(m_data)); + + reset(); + + return val; +} + +template <typename T> +bool Maybe<T>::put(T& t) { + if (m_initialized) { + t = move(m_data); + + reset(); + + return true; + } else { + return false; + } +} + +template <typename T> +void Maybe<T>::set(T const& t) { + emplace(t); +} + +template <typename T> +void Maybe<T>::set(T&& t) { + emplace(forward<T>(t)); +} + +template <typename T> +template <typename... Args> +void Maybe<T>::emplace(Args&&... t) { + reset(); + + new (&m_data) T(forward<Args>(t)...); + m_initialized = true; +} + +template <typename T> +void Maybe<T>::reset() { + if (m_initialized) { + m_initialized = false; + m_data.~T(); + } +} + +template <typename T> +template <typename Function> +auto Maybe<T>::apply(Function&& function) const + -> Maybe<typename std::decay<decltype(function(std::declval<T>()))>::type> { + if (!isValid()) + return {}; + return function(get()); +} + +template <typename T> +template <typename Function> +void Maybe<T>::exec(Function&& function) { + if (isValid()) + function(get()); +} + +template <typename T> +template <typename Function> +auto Maybe<T>::sequence(Function function) const -> decltype(function(std::declval<T>())) { + if (!isValid()) + return {}; + return function(get()); +} + +template <typename T> +std::ostream& operator<<(std::ostream& os, Maybe<T> const& v) { + if (v) + return os << "Just (" << *v << ")"; + else + return os << "Nothing"; +} + +template <typename T> +size_t hash<Maybe<T>>::operator()(Maybe<T> const& m) const { + if (!m) + return 0; + else + return hasher(*m); +} + +} + +#endif diff --git a/source/core/StarMemory.cpp b/source/core/StarMemory.cpp new file mode 100644 index 0000000..3855283 --- /dev/null +++ b/source/core/StarMemory.cpp @@ -0,0 +1,93 @@ +#include "StarMemory.hpp" + +#ifdef STAR_USE_JEMALLOC +#include "jemalloc/jemalloc.h" +#endif + +namespace Star { + +#ifdef STAR_USE_JEMALLOC + void* malloc(size_t size) { + return je_malloc(size); + } + + void* realloc(void* ptr, size_t size) { + return je_realloc(ptr, size); + } + + void free(void* ptr) { + je_free(ptr); + } + + void free(void* ptr, size_t size) { + if (ptr) + je_sdallocx(ptr, size, 0); + } +#else + void* malloc(size_t size) { + return ::malloc(size); + } + + void* realloc(void* ptr, size_t size) { + return ::realloc(ptr, size); + } + + void free(void* ptr) { + return ::free(ptr); + } + + void free(void* ptr, size_t) { + return ::free(ptr); + } +#endif + +} + +void* operator new(std::size_t size) { + auto ptr = Star::malloc(size); + if (!ptr) + throw std::bad_alloc(); + return ptr; +} + +void* operator new[](std::size_t size) { + auto ptr = Star::malloc(size); + if (!ptr) + throw std::bad_alloc(); + return ptr; +} + +// Globally override new and delete. As the per standard, new and delete must +// be defined in global scope, and must not be inline. + +void* operator new(std::size_t size, std::nothrow_t const&) noexcept { + return Star::malloc(size); +} + +void* operator new[](std::size_t size, std::nothrow_t const&) noexcept { + return Star::malloc(size); +} + +void operator delete(void* ptr) noexcept { + Star::free(ptr); +} + +void operator delete[](void* ptr) noexcept { + Star::free(ptr); +} + +void operator delete(void* ptr, std::nothrow_t const&) noexcept { + Star::free(ptr); +} + +void operator delete[](void* ptr, std::nothrow_t const&) noexcept { + Star::free(ptr); +} + +void operator delete(void* ptr, std::size_t size) noexcept { + Star::free(ptr, size); +} + +void operator delete[](void* ptr, std::size_t size) noexcept { + Star::free(ptr, size); +} diff --git a/source/core/StarMemory.hpp b/source/core/StarMemory.hpp new file mode 100644 index 0000000..a75ee20 --- /dev/null +++ b/source/core/StarMemory.hpp @@ -0,0 +1,20 @@ +#ifndef STAR_MEMORY_HPP +#define STAR_MEMORY_HPP + +#include <new> + +#include "StarConfig.hpp" + +namespace Star { + +// Don't want to override global C allocation functions, as our API is +// different. + +void* malloc(size_t size); +void* realloc(void* ptr, size_t size); +void free(void* ptr); +void free(void* ptr, size_t size); + +} + +#endif diff --git a/source/core/StarMultiArray.hpp b/source/core/StarMultiArray.hpp new file mode 100644 index 0000000..ffd432d --- /dev/null +++ b/source/core/StarMultiArray.hpp @@ -0,0 +1,510 @@ +#ifndef STAR_MULTI_ARRAY_HPP +#define STAR_MULTI_ARRAY_HPP + +#include "StarArray.hpp" +#include "StarList.hpp" + +namespace Star { + +STAR_EXCEPTION(MultiArrayException, StarException); + +// Multidimensional array class that wraps a vector as a simple contiguous N +// dimensional array. Values are stored so that the highest dimension is the +// dimension with stride 0, and the lowest dimension has the largest stride. +// +// Due to usage of std::vector, ElementT = bool means that the user must use +// set() and get() rather than operator() +template <typename ElementT, size_t RankN> +class MultiArray { +public: + typedef List<ElementT> Storage; + + typedef ElementT Element; + static size_t const Rank = RankN; + + typedef Array<size_t, Rank> IndexArray; + typedef Array<size_t, Rank> SizeArray; + + typedef typename Storage::iterator iterator; + typedef typename Storage::const_iterator const_iterator; + typedef Element value_type; + + MultiArray(); + template <typename... T> + explicit MultiArray(size_t i, T... rest); + explicit MultiArray(SizeArray const& shape); + explicit MultiArray(SizeArray const& shape, Element const& c); + + SizeArray const& size() const; + size_t size(size_t dimension) const; + + void clear(); + + void resize(SizeArray const& shape); + void resize(SizeArray const& shape, Element const& c); + + template <typename... T> + void resize(size_t i, T... rest); + + void fill(Element const& element); + + // Does not preserve previous element position, array contents will be + // invalid. + void setSize(SizeArray const& shape); + void setSize(SizeArray const& shape, Element const& c); + + template <typename... T> + void setSize(size_t i, T... rest); + + Element& operator()(IndexArray const& index); + Element const& operator()(IndexArray const& index) const; + + template <typename... T> + Element& operator()(size_t i1, T... rest); + template <typename... T> + Element const& operator()(size_t i1, T... rest) const; + + // Throws exception if out of bounds + Element& at(IndexArray const& index); + Element const& at(IndexArray const& index) const; + + template <typename... T> + Element& at(size_t i1, T... rest); + template <typename... T> + Element const& at(size_t i1, T... rest) const; + + // Throws an exception of out of bounds + void set(IndexArray const& index, Element element); + + // Returns default element if out of bounds. + Element get(IndexArray const& index, Element def = Element()); + + // Auto-resizes array if out of bounds + void setResize(IndexArray const& index, Element element); + + // Copy the given array element for element into this array. The shape of + // this array must be at least as large in every dimension as the source + // array + void copy(MultiArray const& source); + void copy(MultiArray const& source, IndexArray const& sourceMin, IndexArray const& sourceMax, IndexArray const& targetMin); + + // op will be called with IndexArray and Element parameters. + template <typename OpType> + void forEach(IndexArray const& min, SizeArray const& size, OpType&& op); + template <typename OpType> + void forEach(IndexArray const& min, SizeArray const& size, OpType&& op) const; + + // Shortcut for calling forEach on the entire array + template <typename OpType> + void forEach(OpType&& op); + template <typename OpType> + void forEach(OpType&& op) const; + + template <typename OStream> + void print(OStream& os) const; + + // Api for more direct access to elements. + + size_t count() const; + + Element const& atIndex(size_t index) const; + Element& atIndex(size_t index); + + Element const* data() const; + Element* data(); + +private: + size_t storageIndex(IndexArray const& index) const; + + template <typename OStream> + void subPrint(OStream& os, IndexArray index, size_t dim) const; + + template <typename OpType> + void subForEach(IndexArray const& min, SizeArray const& size, OpType&& op, IndexArray& index, size_t offset, size_t dim) const; + + template <typename OpType> + void subForEach(IndexArray const& min, SizeArray const& size, OpType&& op, IndexArray& index, size_t offset, size_t dim); + + void subCopy(MultiArray const& source, IndexArray const& sourceMin, IndexArray const& sourceMax, + IndexArray const& targetMin, IndexArray& sourceIndex, IndexArray& targetIndex, size_t dim); + + Storage m_data; + SizeArray m_shape; +}; + +typedef MultiArray<int, 2> MultiArray2I; +typedef MultiArray<size_t, 2> MultiArray2S; +typedef MultiArray<unsigned, 2> MultiArray2U; +typedef MultiArray<float, 2> MultiArray2F; +typedef MultiArray<double, 2> MultiArray2D; + +typedef MultiArray<int, 3> MultiArray3I; +typedef MultiArray<size_t, 3> MultiArray3S; +typedef MultiArray<unsigned, 3> MultiArray3U; +typedef MultiArray<float, 3> MultiArray3F; +typedef MultiArray<double, 3> MultiArray3D; + +typedef MultiArray<int, 4> MultiArray4I; +typedef MultiArray<size_t, 4> MultiArray4S; +typedef MultiArray<unsigned, 4> MultiArray4U; +typedef MultiArray<float, 4> MultiArray4F; +typedef MultiArray<double, 4> MultiArray4D; + +template <typename Element, size_t Rank> +std::ostream& operator<<(std::ostream& os, MultiArray<Element, Rank> const& array); + +template <typename Element, size_t Rank> +MultiArray<Element, Rank>::MultiArray() { + m_shape = SizeArray::filled(0); +} + +template <typename Element, size_t Rank> +MultiArray<Element, Rank>::MultiArray(SizeArray const& shape) { + setSize(shape); +} + +template <typename Element, size_t Rank> +MultiArray<Element, Rank>::MultiArray(SizeArray const& shape, Element const& c) { + setSize(shape, c); +} + +template <typename Element, size_t Rank> +template <typename... T> +MultiArray<Element, Rank>::MultiArray(size_t i, T... rest) { + setSize(SizeArray{i, rest...}); +} + +template <typename Element, size_t Rank> +typename MultiArray<Element, Rank>::SizeArray const& MultiArray<Element, Rank>::size() const { + return m_shape; +} + +template <typename Element, size_t Rank> +size_t MultiArray<Element, Rank>::size(size_t dimension) const { + return m_shape[dimension]; +} + +template <typename Element, size_t Rank> +void MultiArray<Element, Rank>::clear() { + setSize(SizeArray::filled(0)); +} + +template <typename Element, size_t Rank> +void MultiArray<Element, Rank>::resize(SizeArray const& shape) { + if (m_data.empty()) { + setSize(shape); + return; + } + + bool equal = true; + for (size_t i = 0; i < Rank; ++i) + equal = equal && (m_shape[i] == shape[i]); + + if (equal) + return; + + MultiArray newArray(shape); + newArray.copy(*this); + std::swap(*this, newArray); +} + +template <typename Element, size_t Rank> +void MultiArray<Element, Rank>::resize(SizeArray const& shape, Element const& c) { + if (m_data.empty()) { + setSize(shape, c); + return; + } + + bool equal = true; + for (size_t i = 0; i < Rank; ++i) + equal = equal && (m_shape[i] == shape[i]); + + if (equal) + return; + + MultiArray newArray(shape, c); + newArray.copy(*this); + *this = std::move(newArray); +} + +template <typename Element, size_t Rank> +template <typename... T> +void MultiArray<Element, Rank>::resize(size_t i, T... rest) { + resize(SizeArray{i, rest...}); +} + +template <typename Element, size_t Rank> +void MultiArray<Element, Rank>::fill(Element const& element) { + std::fill(m_data.begin(), m_data.end(), element); +} + +template <typename Element, size_t Rank> +void MultiArray<Element, Rank>::setSize(SizeArray const& shape) { + size_t storageSize = 1; + for (size_t i = 0; i < Rank; ++i) { + m_shape[i] = shape[i]; + storageSize *= shape[i]; + } + + m_data.resize(storageSize); +} + +template <typename Element, size_t Rank> +void MultiArray<Element, Rank>::setSize(SizeArray const& shape, Element const& c) { + size_t storageSize = 1; + for (size_t i = 0; i < Rank; ++i) { + m_shape[i] = shape[i]; + storageSize *= shape[i]; + } + m_data.resize(storageSize, c); +} + +template <typename Element, size_t Rank> +template <typename... T> +void MultiArray<Element, Rank>::setSize(size_t i, T... rest) { + setSize({i, rest...}); +} + +template <typename Element, size_t Rank> +Element& MultiArray<Element, Rank>::operator()(IndexArray const& index) { + return m_data[storageIndex(index)]; +} + +template <typename Element, size_t Rank> +Element const& MultiArray<Element, Rank>::operator()(IndexArray const& index) const { + return m_data[storageIndex(index)]; +} + +template <typename Element, size_t Rank> +template <typename... T> +Element& MultiArray<Element, Rank>::operator()(size_t i1, T... rest) { + return m_data[storageIndex(IndexArray(i1, rest...))]; +} + +template <typename Element, size_t Rank> +template <typename... T> +Element const& MultiArray<Element, Rank>::operator()(size_t i1, T... rest) const { + return m_data[storageIndex(IndexArray(i1, rest...))]; +} + +template <typename Element, size_t Rank> +Element const& MultiArray<Element, Rank>::at(IndexArray const& index) const { + for (size_t i = Rank; i != 0; --i) { + if (index[i - 1] >= m_shape[i - 1]) + throw MultiArrayException(strf("Out of bounds on MultiArray::at(%s)", index)); + } + + return m_data[storageIndex(index)]; +} + +template <typename Element, size_t Rank> +Element& MultiArray<Element, Rank>::at(IndexArray const& index) { + for (size_t i = Rank; i != 0; --i) { + if (index[i - 1] >= m_shape[i - 1]) + throw MultiArrayException(strf("Out of bounds on MultiArray::at(%s)", index)); + } + + return m_data[storageIndex(index)]; +} + +template <typename Element, size_t Rank> +template <typename... T> +Element& MultiArray<Element, Rank>::at(size_t i1, T... rest) { + return at(IndexArray(i1, rest...)); +} + +template <typename Element, size_t Rank> +template <typename... T> +Element const& MultiArray<Element, Rank>::at(size_t i1, T... rest) const { + return at(IndexArray(i1, rest...)); +} + +template <typename Element, size_t Rank> +void MultiArray<Element, Rank>::set(IndexArray const& index, Element element) { + for (size_t i = Rank; i != 0; --i) { + if (index[i - 1] >= m_shape[i - 1]) + throw MultiArrayException(strf("Out of bounds on MultiArray::set(%s)", index)); + } + + m_data[storageIndex(index)] = move(element); +} + +template <typename Element, size_t Rank> +Element MultiArray<Element, Rank>::get(IndexArray const& index, Element def) { + for (size_t i = Rank; i != 0; --i) { + if (index[i - 1] >= m_shape[i - 1]) + return move(def); + } + + return m_data[storageIndex(index)]; +} + +template <typename Element, size_t Rank> +void MultiArray<Element, Rank>::setResize(IndexArray const& index, Element element) { + SizeArray newShape; + for (size_t i = 0; i < Rank; ++i) + newShape[i] = std::max(m_shape[i], index[i] + 1); + resize(newShape); + + m_data[storageIndex(index)] = move(element); +} + +template <typename Element, size_t Rank> +void MultiArray<Element, Rank>::copy(MultiArray const& source) { + IndexArray max; + for (size_t i = 0; i < Rank; ++i) + max[i] = std::min(size(i), source.size(i)); + + copy(source, IndexArray::filled(0), max, IndexArray::filled(0)); +} + +template <typename Element, size_t Rank> +void MultiArray<Element, Rank>::copy(MultiArray const& source, IndexArray const& sourceMin, IndexArray const& sourceMax, IndexArray const& targetMin) { + IndexArray sourceIndex; + IndexArray targetIndex; + subCopy(source, sourceMin, sourceMax, targetMin, sourceIndex, targetIndex, 0); +} + +template <typename Element, size_t Rank> +template <typename OpType> +void MultiArray<Element, Rank>::forEach(IndexArray const& min, SizeArray const& size, OpType&& op) { + IndexArray index; + subForEach(min, size, forward<OpType>(op), index, 0, 0); +} + +template <typename Element, size_t Rank> +template <typename OpType> +void MultiArray<Element, Rank>::forEach(IndexArray const& min, SizeArray const& size, OpType&& op) const { + IndexArray index; + subForEach(min, size, forward<OpType>(op), index, 0, 0); +} + +template <typename Element, size_t Rank> +template <typename OpType> +void MultiArray<Element, Rank>::forEach(OpType&& op) { + forEach(IndexArray::filled(0), size(), forward<OpType>(op)); +} + +template <typename Element, size_t Rank> +template <typename OpType> +void MultiArray<Element, Rank>::forEach(OpType&& op) const { + forEach(IndexArray::filled(0), size(), forward<OpType>(op)); +} + +template <typename Element, size_t Rank> +template <typename OStream> +void MultiArray<Element, Rank>::print(OStream& os) const { + subPrint(os, IndexArray(), 0); +} + +template <typename Element, size_t Rank> +size_t MultiArray<Element, Rank>::count() const { + return m_data.size(); +} + +template <typename Element, size_t Rank> +Element const& MultiArray<Element, Rank>::atIndex(size_t index) const { + return m_data[index]; +} + +template <typename Element, size_t Rank> +Element& MultiArray<Element, Rank>::atIndex(size_t index) { + return m_data[index]; +} + +template <typename Element, size_t Rank> +Element const* MultiArray<Element, Rank>::data() const { + return m_data.ptr(); +} + +template <typename Element, size_t Rank> +Element* MultiArray<Element, Rank>::data() { + return m_data.ptr(); +} + +template <typename Element, size_t Rank> +size_t MultiArray<Element, Rank>::storageIndex(IndexArray const& index) const { + size_t loc = index[0]; + starAssert(index[0] < m_shape[0]); + for (size_t i = 1; i < Rank; ++i) { + loc = loc * m_shape[i] + index[i]; + starAssert(index[i] < m_shape[i]); + } + return loc; +} + +template <typename Element, size_t Rank> +template <typename OStream> +void MultiArray<Element, Rank>::subPrint(OStream& os, IndexArray index, size_t dim) const { + if (dim == Rank - 1) { + for (size_t i = 0; i < m_shape[dim]; ++i) { + index[dim] = i; + os << m_data[storageIndex(index)] << ' '; + } + os << std::endl; + } else { + for (size_t i = 0; i < m_shape[dim]; ++i) { + index[dim] = i; + subPrint(os, index, dim + 1); + } + os << std::endl; + } +} + +template <typename Element, size_t Rank> +template <typename OpType> +void MultiArray<Element, Rank>::subForEach(IndexArray const& min, SizeArray const& size, OpType&& op, IndexArray& index, size_t offset, size_t dim) { + size_t minIndex = min[dim]; + size_t maxIndex = minIndex + size[dim]; + for (size_t i = minIndex; i < maxIndex; ++i) { + index[dim] = i; + if (dim == Rank - 1) + op(index, m_data[offset + i]); + else + subForEach(min, size, forward<OpType>(op), index, (offset + i) * m_shape[dim + 1], dim + 1); + } +} + +template <typename Element, size_t Rank> +template <typename OpType> +void MultiArray<Element, Rank>::subForEach(IndexArray const& min, SizeArray const& size, OpType&& op, IndexArray& index, size_t offset, size_t dim) const { + size_t minIndex = min[dim]; + size_t maxIndex = minIndex + size[dim]; + for (size_t i = minIndex; i < maxIndex; ++i) { + index[dim] = i; + if (dim == Rank - 1) + op(index, m_data[offset + i]); + else + subForEach(min, size, forward<OpType>(op), index, (offset + i) * m_shape[dim + 1], dim + 1); + } +} + +template <typename Element, size_t Rank> +void MultiArray<Element, Rank>::subCopy(MultiArray const& source, IndexArray const& sourceMin, IndexArray const& sourceMax, + IndexArray const& targetMin, IndexArray& sourceIndex, IndexArray& targetIndex, size_t dim) { + size_t w = sourceMax[dim] - sourceMin[dim]; + if (dim < Rank - 1) { + for (size_t i = 0; i < w; ++i) { + sourceIndex[dim] = i + sourceMin[dim]; + targetIndex[dim] = i + targetMin[dim]; + subCopy(source, sourceMin, sourceMax, targetMin, sourceIndex, targetIndex, dim + 1); + } + } else { + sourceIndex[dim] = sourceMin[dim]; + targetIndex[dim] = targetMin[dim]; + size_t sourceStorageStart = source.storageIndex(sourceIndex); + size_t targetStorageStart = storageIndex(targetIndex); + for (size_t i = 0; i < w; ++i) + m_data[targetStorageStart + i] = source.m_data[sourceStorageStart + i]; + } +} + +template <typename Element, size_t Rank> +std::ostream& operator<<(std::ostream& os, MultiArray<Element, Rank> const& array) { + array.print(os); + return os; +} + +} + +#endif diff --git a/source/core/StarMultiArrayInterpolator.hpp b/source/core/StarMultiArrayInterpolator.hpp new file mode 100644 index 0000000..40a708d --- /dev/null +++ b/source/core/StarMultiArrayInterpolator.hpp @@ -0,0 +1,539 @@ +#ifndef STAR_MULTI_ARRAY_INTERPOLATOR_HPP +#define STAR_MULTI_ARRAY_INTERPOLATOR_HPP + +#include "StarMultiArray.hpp" +#include "StarInterpolation.hpp" + +namespace Star { + +template <typename MultiArrayT, typename PositionT> +struct MultiArrayInterpolator2 { + typedef MultiArrayT MultiArray; + typedef PositionT Position; + + typedef typename MultiArray::Element Element; + static size_t const Rank = MultiArray::Rank; + + typedef Array<size_t, Rank> IndexList; + typedef Array<size_t, Rank> SizeList; + typedef Array<Position, Rank> PositionList; + typedef Array<Position, 2> WeightList; + + typedef std::function<WeightList(Position)> WeightFunction; + + WeightFunction weightFunction; + BoundMode boundMode; + + MultiArrayInterpolator2(WeightFunction wf, BoundMode b = BoundMode::Clamp) + : weightFunction(wf), boundMode(b) {} + + Element interpolate(MultiArray const& array, PositionList const& coord) const { + IndexList imin; + IndexList imax; + PositionList offset; + + for (size_t i = 0; i < Rank; ++i) { + auto binfo = getBound2(coord[i], array.size(i), boundMode); + imin[i] = binfo.i0; + imax[i] = binfo.i1; + offset[i] = binfo.offset; + } + + return interpolateSub(array, imin, imax, offset, IndexList(), 0); + } + + Element interpolateSub( + MultiArray const& array, + IndexList const& imin, IndexList const& imax, + PositionList const& offset, IndexList const& index, + size_t const dim) const { + IndexList minIndex = index; + IndexList maxIndex = index; + + minIndex[dim] = imin[dim]; + maxIndex[dim] = imax[dim]; + + WeightList weights = weightFunction(offset[dim]); + + if (dim == Rank - 1) { + return weights[0] * array(minIndex) + weights[1] * array(maxIndex); + } else { + return + weights[0] * interpolateSub(array, imin, imax, offset, minIndex, dim+1) + + weights[1] * interpolateSub(array, imin, imax, offset, maxIndex, dim+1); + } + } +}; + +template <typename MultiArrayT, typename PositionT> +struct MultiArrayInterpolator4 { + typedef MultiArrayT MultiArray; + typedef PositionT Position; + + typedef typename MultiArray::Element Element; + static size_t const Rank = MultiArray::Rank; + + typedef Array<size_t, Rank> IndexList; + typedef Array<size_t, Rank> SizeList; + typedef Array<Position, Rank> PositionList; + typedef Array<Position, 4> WeightList; + + typedef std::function<WeightList(Position)> WeightFunction; + + WeightFunction weightFunction; + BoundMode boundMode; + + MultiArrayInterpolator4(WeightFunction wf, BoundMode b = BoundMode::Clamp) + : weightFunction(wf), boundMode(b) {} + + Element interpolate(MultiArray const& array, PositionList const& coord) const { + IndexList index0; + IndexList index1; + IndexList index2; + IndexList index3; + PositionList offset; + + for (size_t i = 0; i < Rank; ++i) { + auto bound = getBound4(coord[i], array.size(i), boundMode); + index0[i] = bound.i0; + index1[i] = bound.i1; + index2[i] = bound.i2; + index3[i] = bound.i3; + offset[i] = bound.offset; + } + + return interpolateSub(array, index0, index1, index2, index3, offset, IndexList(), 0); + } + + Element interpolateSub( + MultiArray const& array, + IndexList const& i0, IndexList const& i1, + IndexList const& i2, IndexList const& i3, + PositionList const& offset, IndexList const& index, + size_t const dim) const { + IndexList index0 = index; + IndexList index1 = index; + IndexList index2 = index; + IndexList index3 = index; + + index0[dim] = i0[dim]; + index1[dim] = i1[dim]; + index2[dim] = i2[dim]; + index3[dim] = i3[dim]; + + WeightList weights = weightFunction(offset[dim]); + + if (dim == Rank - 1) { + return + weights[0] * array(index0) + + weights[1] * array(index1) + + weights[2] * array(index2) + + weights[3] * array(index3); + } else { + return + weights[0] * interpolateSub(array, i0, i1, i2, i3, offset, index0, dim+1) + + weights[1] * interpolateSub(array, i0, i1, i2, i3, offset, index1, dim+1) + + weights[2] * interpolateSub(array, i0, i1, i2, i3, offset, index2, dim+1) + + weights[3] * interpolateSub(array, i0, i1, i2, i3, offset, index3, dim+1); + } + } +}; + +template <typename MultiArrayT, typename PositionT> +struct MultiArrayPiecewiseInterpolator { + typedef MultiArrayT MultiArray; + typedef PositionT Position; + + typedef typename MultiArray::Element Element; + static size_t const Rank = MultiArray::Rank; + + typedef Array<size_t, Rank> IndexList; + typedef Array<size_t, Rank> SizeList; + typedef Array<Position, Rank> PositionList; + typedef Array<Position, 2> WeightList; + + typedef std::function<WeightList(Position)> WeightFunction; + + struct PiecewiseRange { + size_t dim; + Position offset; + + bool operator<(PiecewiseRange const& pr) const { + return pr.offset < offset; + } + }; + typedef Array<PiecewiseRange, Rank> PiecewiseRangeList; + + WeightFunction weightFunction; + BoundMode boundMode; + + MultiArrayPiecewiseInterpolator(WeightFunction wf, BoundMode b = BoundMode::Clamp) + : weightFunction(wf), boundMode(b) {} + + // O(n) for n-dimensions. + Element interpolate(MultiArray const& array, PositionList const& coord) const { + PiecewiseRangeList piecewiseRangeList; + + IndexList minIndex; + IndexList maxIndex; + + for (size_t i = 0; i < Rank; ++i) { + PiecewiseRange range; + range.dim = i; + + auto bound = getBound2(coord[i], array.size(i), boundMode); + minIndex[i] = bound.i0; + maxIndex[i] = bound.i1; + range.offset = bound.offset; + + piecewiseRangeList[i] = range; + } + + std::sort(piecewiseRangeList.begin(), piecewiseRangeList.end()); + + IndexList location = minIndex; + Element result = array(location); + Element last = result; + Element current; + + for (size_t i = 0; i < Rank; ++i) { + auto const& pr = piecewiseRangeList[i]; + location[pr.dim] = maxIndex[pr.dim]; + current = array(location); + + WeightList weights = weightFunction(pr.offset); + result += last * (weights[0] - 1) + current * weights[1]; + last = current; + } + + return result; + } +}; + +// Template specializations for Rank 2 + +template <typename ElementT, typename PositionT> +struct MultiArrayInterpolator2<MultiArray<ElementT, 2>, PositionT> { + typedef Star::MultiArray<ElementT, 2> MultiArray; + typedef PositionT Position; + + typedef typename MultiArray::Element Element; + static size_t const Rank = 2; + + typedef Array<size_t, Rank> IndexList; + typedef Array<size_t, Rank> SizeList; + typedef Array<Position, Rank> PositionList; + typedef Array<Position, 2> WeightList; + + typedef std::function<WeightList(Position)> WeightFunction; + + WeightFunction weightFunction; + BoundMode boundMode; + + MultiArrayInterpolator2(WeightFunction wf, BoundMode b = BoundMode::Clamp) + : weightFunction(wf), boundMode(b) {} + + Element interpolate(MultiArray const& array, PositionList const& coord) const { + IndexList imin; + IndexList imax; + PositionList offset; + + for (size_t i = 0; i < Rank; ++i) { + auto bound = getBound2(coord[i], array.size(i), boundMode); + imin[i] = bound.i0; + imax[i] = bound.i1; + offset[i] = bound.offset; + } + + WeightList xweights = weightFunction(offset[0]); + WeightList yweights = weightFunction(offset[1]); + + return + xweights[0] * (yweights[0] * array(imin[0], imin[1]) + yweights[1] * array(imin[0], imax[1])) + + xweights[1] * (yweights[0] * array(imax[0], imin[1]) + yweights[1] * array(imax[0], imax[1])); + } +}; + +template <typename ElementT, typename PositionT> +struct MultiArrayInterpolator4<MultiArray<ElementT, 2>, PositionT> { + typedef Star::MultiArray<ElementT, 2> MultiArray; + typedef PositionT Position; + + typedef typename MultiArray::Element Element; + static size_t const Rank = 2; + + typedef Array<size_t, Rank> IndexList; + typedef Array<size_t, Rank> SizeList; + typedef Array<Position, Rank> PositionList; + typedef Array<Position, 4> WeightList; + + typedef std::function<WeightList(Position)> WeightFunction; + + WeightFunction weightFunction; + BoundMode boundMode; + + MultiArrayInterpolator4(WeightFunction wf, BoundMode b = BoundMode::Clamp) + : weightFunction(wf), boundMode(b) {} + + Element interpolate(MultiArray const& array, PositionList const& coord) const { + IndexList index0; + IndexList index1; + IndexList index2; + IndexList index3; + PositionList offset; + + for (size_t i = 0; i < Rank; ++i) { + auto bound = getBound4(coord[i], array.size(i), boundMode); + index0[i] = bound.i0; + index1[i] = bound.i1; + index2[i] = bound.i2; + index3[i] = bound.i3; + offset[i] = bound.offset; + } + + WeightList xweights = weightFunction(offset[0]); + WeightList yweights = weightFunction(offset[1]); + + return + xweights[0] * ( + yweights[0] * array(index0[0], index0[1]) + + yweights[1] * array(index0[0], index1[1]) + + yweights[2] * array(index0[0], index2[1]) + + yweights[3] * array(index0[0], index3[1]) + ) + + xweights[1] * ( + yweights[0] * array(index1[0], index0[1]) + + yweights[1] * array(index1[0], index1[1]) + + yweights[2] * array(index1[0], index2[1]) + + yweights[3] * array(index1[0], index3[1]) + ) + + xweights[2] * ( + yweights[0] * array(index2[0], index0[1]) + + yweights[1] * array(index2[0], index1[1]) + + yweights[2] * array(index2[0], index2[1]) + + yweights[3] * array(index2[0], index3[1]) + ) + + xweights[3] * ( + yweights[0] * array(index3[0], index0[1]) + + yweights[1] * array(index3[0], index1[1]) + + yweights[2] * array(index3[0], index2[1]) + + yweights[3] * array(index3[0], index3[1]) + ); + } +}; + +// Template specializations for Rank 3 + +template <typename ElementT, typename PositionT> +struct MultiArrayInterpolator2<MultiArray<ElementT, 3>, PositionT> { + typedef Star::MultiArray<ElementT, 3> MultiArray; + typedef PositionT Position; + + typedef typename MultiArray::Element Element; + static size_t const Rank = 3; + + typedef Array<size_t, Rank> IndexList; + typedef Array<size_t, Rank> SizeList; + typedef Array<Position, Rank> PositionList; + typedef Array<Position, 2> WeightList; + + typedef std::function<WeightList(Position)> WeightFunction; + + WeightFunction weightFunction; + BoundMode boundMode; + + MultiArrayInterpolator2(WeightFunction wf, BoundMode b = BoundMode::Clamp) + : weightFunction(wf), boundMode(b) {} + + Element interpolate(MultiArray const& array, PositionList const& coord) const { + IndexList imin; + IndexList imax; + PositionList offset; + + for (size_t i = 0; i < Rank; ++i) { + auto bound = getBound2(coord[i], array.size(i), boundMode); + imin[i] = bound.i0; + imax[i] = bound.i1; + offset[i] = bound.offset; + } + + WeightList xweights = weightFunction(offset[0]); + WeightList yweights = weightFunction(offset[1]); + WeightList zweights = weightFunction(offset[2]); + + return + xweights[0] * ( + yweights[0] * ( + zweights[0] * array(imin[0], imin[1], imin[2]) + + zweights[1] * array(imin[0], imin[1], imax[2]) + ) + + yweights[1] * ( + zweights[0] * array(imin[0], imax[1], imin[2]) + + zweights[1] * array(imin[0], imax[1], imax[2]) + ) + ) + + xweights[1] * ( + yweights[0] * ( + zweights[0] * array(imax[0], imin[1], imin[2]) + + zweights[1] * array(imax[0], imin[1], imax[2]) + ) + + yweights[1] * ( + zweights[0] * array(imax[0], imax[1], imin[2]) + + zweights[1] * array(imax[0], imax[1], imax[2]) + ) + ); + } +}; + +template <typename ElementT, typename PositionT> +struct MultiArrayInterpolator4<MultiArray<ElementT, 3>, PositionT> { + typedef Star::MultiArray<ElementT, 3> MultiArray; + typedef PositionT Position; + + typedef typename MultiArray::Element Element; + static size_t const Rank = 3; + + typedef Array<size_t, Rank> IndexList; + typedef Array<size_t, Rank> SizeList; + typedef Array<Position, Rank> PositionList; + typedef Array<Position, 4> WeightList; + + typedef std::function<WeightList(Position)> WeightFunction; + + WeightFunction weightFunction; + BoundMode boundMode; + + MultiArrayInterpolator4(WeightFunction wf, BoundMode b = BoundMode::Clamp) + : weightFunction(wf), boundMode(b) {} + + Element interpolate(MultiArray const& array, PositionList const& coord) const { + IndexList index0; + IndexList index1; + IndexList index2; + IndexList index3; + PositionList offset; + + for (size_t i = 0; i < Rank; ++i) { + auto bound = getBound4(coord[i], array.size(i), boundMode); + index0[i] = bound.i0; + index1[i] = bound.i1; + index2[i] = bound.i2; + index3[i] = bound.i3; + offset[i] = bound.offset; + } + + WeightList xweights = weightFunction(offset[0]); + WeightList yweights = weightFunction(offset[1]); + WeightList zweights = weightFunction(offset[2]); + + return + xweights[0] * ( + yweights[0] * ( + zweights[0] * array(index0[0], index0[1], index0[2]) + + zweights[1] * array(index0[0], index0[1], index1[2]) + + zweights[2] * array(index0[0], index0[1], index2[2]) + + zweights[3] * array(index0[0], index0[1], index3[2]) + ) + + yweights[1] * ( + zweights[0] * array(index0[0], index1[1], index0[2]) + + zweights[1] * array(index0[0], index1[1], index1[2]) + + zweights[2] * array(index0[0], index1[1], index2[2]) + + zweights[3] * array(index0[0], index1[1], index3[2]) + ) + + yweights[2] * ( + zweights[0] * array(index0[0], index2[1], index0[2]) + + zweights[1] * array(index0[0], index2[1], index1[2]) + + zweights[2] * array(index0[0], index2[1], index2[2]) + + zweights[3] * array(index0[0], index2[1], index3[2]) + ) + + yweights[3] * ( + zweights[0] * array(index0[0], index3[1], index0[2]) + + zweights[1] * array(index0[0], index3[1], index1[2]) + + zweights[2] * array(index0[0], index3[1], index2[2]) + + zweights[3] * array(index0[0], index3[1], index3[2]) + ) + ) + + xweights[1] * ( + yweights[0] * ( + zweights[0] * array(index1[0], index0[1], index0[2]) + + zweights[1] * array(index1[0], index0[1], index1[2]) + + zweights[2] * array(index1[0], index0[1], index2[2]) + + zweights[3] * array(index1[0], index0[1], index3[2]) + ) + + yweights[1] * ( + zweights[0] * array(index1[0], index1[1], index0[2]) + + zweights[1] * array(index1[0], index1[1], index1[2]) + + zweights[2] * array(index1[0], index1[1], index2[2]) + + zweights[3] * array(index1[0], index1[1], index3[2]) + ) + + yweights[2] * ( + zweights[0] * array(index1[0], index2[1], index0[2]) + + zweights[1] * array(index1[0], index2[1], index1[2]) + + zweights[2] * array(index1[0], index2[1], index2[2]) + + zweights[3] * array(index1[0], index2[1], index3[2]) + ) + + yweights[3] * ( + zweights[0] * array(index1[0], index3[1], index0[2]) + + zweights[1] * array(index1[0], index3[1], index1[2]) + + zweights[2] * array(index1[0], index3[1], index2[2]) + + zweights[3] * array(index1[0], index3[1], index3[2]) + ) + ) + + xweights[2] * ( + yweights[0] * ( + zweights[0] * array(index2[0], index0[1], index0[2]) + + zweights[1] * array(index2[0], index0[1], index1[2]) + + zweights[2] * array(index2[0], index0[1], index2[2]) + + zweights[3] * array(index2[0], index0[1], index3[2]) + ) + + yweights[1] * ( + zweights[0] * array(index2[0], index1[1], index0[2]) + + zweights[1] * array(index2[0], index1[1], index1[2]) + + zweights[2] * array(index2[0], index1[1], index2[2]) + + zweights[3] * array(index2[0], index1[1], index3[2]) + ) + + yweights[2] * ( + zweights[0] * array(index2[0], index2[1], index0[2]) + + zweights[1] * array(index2[0], index2[1], index1[2]) + + zweights[2] * array(index2[0], index2[1], index2[2]) + + zweights[3] * array(index2[0], index2[1], index3[2]) + ) + + yweights[3] * ( + zweights[0] * array(index2[0], index3[1], index0[2]) + + zweights[1] * array(index2[0], index3[1], index1[2]) + + zweights[2] * array(index2[0], index3[1], index2[2]) + + zweights[3] * array(index2[0], index3[1], index3[2]) + ) + ) + + xweights[3] * ( + yweights[0] * ( + zweights[0] * array(index3[0], index0[1], index0[2]) + + zweights[1] * array(index3[0], index0[1], index1[2]) + + zweights[2] * array(index3[0], index0[1], index2[2]) + + zweights[3] * array(index3[0], index0[1], index3[2]) + ) + + yweights[1] * ( + zweights[0] * array(index3[0], index1[1], index0[2]) + + zweights[1] * array(index3[0], index1[1], index1[2]) + + zweights[2] * array(index3[0], index1[1], index2[2]) + + zweights[3] * array(index3[0], index1[1], index3[2]) + ) + + yweights[2] * ( + zweights[0] * array(index3[0], index2[1], index0[2]) + + zweights[1] * array(index3[0], index2[1], index1[2]) + + zweights[2] * array(index3[0], index2[1], index2[2]) + + zweights[3] * array(index3[0], index2[1], index3[2]) + ) + + yweights[3] * ( + zweights[0] * array(index3[0], index3[1], index0[2]) + + zweights[1] * array(index3[0], index3[1], index1[2]) + + zweights[2] * array(index3[0], index3[1], index2[2]) + + zweights[3] * array(index3[0], index3[1], index3[2]) + ) + ); + } +}; + +} + +#endif diff --git a/source/core/StarMultiTable.hpp b/source/core/StarMultiTable.hpp new file mode 100644 index 0000000..3a44634 --- /dev/null +++ b/source/core/StarMultiTable.hpp @@ -0,0 +1,169 @@ +#ifndef STAR_MULTI_TABLE_HPP +#define STAR_MULTI_TABLE_HPP + +#include "StarMultiArrayInterpolator.hpp" + +namespace Star { + +// Provides a method for storing, retrieving, and interpolating uneven +// n-variate data. Access times involve a binary search over the domain of +// each dimension, so is O(log(n)*m) where n is the size of the largest +// dimension, and m is the table_rank. +template <typename ElementT, typename PositionT, size_t RankN> +class MultiTable { +public: + typedef ElementT Element; + typedef PositionT Position; + static size_t const Rank = RankN; + + typedef Star::MultiArray<ElementT, RankN> MultiArray; + + typedef Star::MultiArrayInterpolator2<MultiArray, Position> Interpolator2; + typedef Star::MultiArrayInterpolator4<MultiArray, Position> Interpolator4; + typedef Star::MultiArrayPiecewiseInterpolator<MultiArray, Position> PiecewiseInterpolator; + + typedef Array<Position, Rank> PositionArray; + typedef Array<Position, 2> WeightArray2; + typedef Array<Position, 4> WeightArray4; + typedef typename MultiArray::SizeArray SizeArray; + typedef typename MultiArray::IndexArray IndexArray; + typedef List<Position> Range; + typedef Array<Range, Rank> RangeArray; + + typedef std::function<WeightArray2(Position)> WeightFunction2; + typedef std::function<WeightArray4(Position)> WeightFunction4; + typedef std::function<Element(PositionArray const&)> InterpolateFunction; + + MultiTable() : m_interpolationMode(InterpolationMode::Linear), m_boundMode(BoundMode::Clamp) {} + + // Set input ranges on a particular dimension. Will resize underlying storage + // to fit range. + void setRange(std::size_t dim, Range const& range) { + SizeArray sizes = m_array.size(); + sizes[dim] = range.size(); + m_array.resize(sizes); + + m_ranges[dim] = range; + } + + void setRanges(RangeArray const& ranges) { + SizeArray arraySize; + + for (size_t dim = 0; dim < Rank; ++dim) { + arraySize[dim] = ranges[dim].size(); + m_ranges[dim] = ranges[dim]; + } + + m_array.resize(arraySize); + } + + // Set array element based on index. + void set(IndexArray const& index, Element const& element) { + m_array.set(index, element); + } + + // Get array element based on index. + Element const& get(IndexArray const& index) const { + return m_array(index); + } + + MultiArray const& array() const { + return m_array; + } + + MultiArray& array() { + return m_array; + } + + InterpolationMode interpolationMode() const { + return m_interpolationMode; + } + + void setInterpolationMode(InterpolationMode interpolationMode) { + m_interpolationMode = interpolationMode; + } + + BoundMode boundMode() const { + return m_boundMode; + } + + void setBoundMode(BoundMode boundMode) { + m_boundMode = boundMode; + } + + Element interpolate(PositionArray const& coord) const { + if (m_interpolationMode == InterpolationMode::HalfStep) { + PiecewiseInterpolator piecewiseInterpolator(StepWeightOperator<Position>(), m_boundMode); + return piecewiseInterpolator.interpolate(m_array, toIndexSpace(coord)); + + } else if (m_interpolationMode == InterpolationMode::Linear) { + Interpolator2 interpolator2(LinearWeightOperator<Position>(), m_boundMode); + return interpolator2.interpolate(m_array, toIndexSpace(coord)); + + } else if (m_interpolationMode == InterpolationMode::Cubic) { + // MultiTable uses CubicWeights with linear extrapolation (not + // configurable atm) + Interpolator4 interpolator4(Cubic4WeightOperator<Position>(true), m_boundMode); + return interpolator4.interpolate(m_array, toIndexSpace(coord)); + + } else { + throw MathException("Unsupported interpolation type in MultiTable::interpolate"); + } + } + + // Synonym for inteprolate + Element operator()(PositionArray const& coord) const { + return interpolate(coord); + } + + // op should take a PositionArray parameter and return an element. + template <typename OpType> + void eval(OpType op) { + m_array.forEach(EvalWrapper<OpType>(op, *this)); + } + +private: + template <typename Coordinate> + inline PositionArray toIndexSpace(Coordinate const& coord) const { + PositionArray indexCoord; + for (size_t i = 0; i < Rank; ++i) + indexCoord[i] = inverseLinearInterpolateLower(m_ranges[i].begin(), m_ranges[i].end(), coord[i]); + return indexCoord; + } + + template <typename OpType> + struct EvalWrapper { + EvalWrapper(OpType& o, MultiTable const& t) + : op(o), table(t) {} + + template <typename IndexArray> + void operator()(IndexArray const& indexArray, Element& element) { + PositionArray rangeArray; + for (size_t i = 0; i < Rank; ++i) + rangeArray[i] = table.m_ranges[i][indexArray[i]]; + + element = op(rangeArray); + } + + OpType& op; + MultiTable const& table; + }; + + RangeArray m_ranges; + MultiArray m_array; + InterpolationMode m_interpolationMode; + BoundMode m_boundMode; +}; + +typedef MultiTable<float, float, 2> MultiTable2F; +typedef MultiTable<double, double, 2> MultiTable2D; + +typedef MultiTable<float, float, 3> MultiTable3F; +typedef MultiTable<double, double, 3> MultiTable3D; + +typedef MultiTable<float, float, 4> MultiTable4F; +typedef MultiTable<double, double, 4> MultiTable4D; + +} + +#endif diff --git a/source/core/StarNetElement.cpp b/source/core/StarNetElement.cpp new file mode 100644 index 0000000..914badd --- /dev/null +++ b/source/core/StarNetElement.cpp @@ -0,0 +1,22 @@ +#include "StarNetElement.hpp" + + +namespace Star { + +uint64_t NetElementVersion::current() const { + return m_version; +} + +void NetElementVersion::increment() { + ++m_version; +} + +void NetElement::enableNetInterpolation(float) {} + +void NetElement::disableNetInterpolation() {} + +void NetElement::tickNetInterpolation(float) {} + +void NetElement::blankNetDelta(float) {} + +} diff --git a/source/core/StarNetElement.hpp b/source/core/StarNetElement.hpp new file mode 100644 index 0000000..5d21fe3 --- /dev/null +++ b/source/core/StarNetElement.hpp @@ -0,0 +1,62 @@ +#ifndef STAR_NET_ELEMENT_HPP +#define STAR_NET_ELEMENT_HPP + +#include "StarDataStream.hpp" + +namespace Star { + +// Monotonically increasing NetElementVersion shared between all NetElements in +// a network. +class NetElementVersion { +public: + uint64_t current() const; + void increment(); + +private: + uint64_t m_version = 0; +}; + +// Primary interface for the composable network synchronizable element system. +class NetElement { +public: + virtual ~NetElement() = default; + + // A network of NetElements will have a shared monotinically increasing + // NetElementVersion. When elements are updated, they will mark the version + // number at the time they are updated so that a delta can be constructed + // that contains only changes since any past version. + virtual void initNetVersion(NetElementVersion const* version = nullptr) = 0; + + // Full store / load of the entire element. + virtual void netStore(DataStream& ds) const = 0; + virtual void netLoad(DataStream& ds) = 0; + + // Enables interpolation mode. If interpolation mode is enabled, then + // NetElements will delay presenting incoming delta data for the + // 'interpolationTime' parameter given in readNetDelta, and smooth between + // received values. When interpolation is enabled, tickNetInterpolation must + // be periodically called to smooth values forward in time. If + // extrapolationHint is given, this may be used as a hint for the amount of + // time to extrapolate forward if no deltas are received. + virtual void enableNetInterpolation(float extrapolationHint = 0.0f); + virtual void disableNetInterpolation(); + virtual void tickNetInterpolation(float dt); + + // Write all the state changes that have happened since (and including) + // fromVersion. The normal way to use this would be to call writeDelta with + // the version at the time of the *last* call to writeDelta, + 1. If + // fromVersion is 0, this will always write the full state. Should return + // true if a delta was needed and was written to DataStream, false otherwise. + virtual bool writeNetDelta(DataStream& ds, uint64_t fromVersion) const = 0; + // Read a delta written by writeNetDelta. 'interpolationTime' is the time in + // the future that data from this delta should be delayed and smoothed into, + // if interpolation is enabled. + virtual void readNetDelta(DataStream& ds, float interpolationTime = 0.0) = 0; + // When extrapolating, it is important to notify when a delta WOULD have been + // received even if no deltas are produced, so no extrapolation takes place. + virtual void blankNetDelta(float interpolationTime); +}; + +} + +#endif diff --git a/source/core/StarNetElementBasicFields.cpp b/source/core/StarNetElementBasicFields.cpp new file mode 100644 index 0000000..bbb5c5b --- /dev/null +++ b/source/core/StarNetElementBasicFields.cpp @@ -0,0 +1,64 @@ +#include "StarNetElementBasicFields.hpp" + +namespace Star { + +void NetElementSize::readData(DataStream& ds, size_t& v) const { + uint64_t s = ds.readVlqU(); + if (s == 0) + v = NPos; + else + v = s - 1; +} + +void NetElementSize::writeData(DataStream& ds, size_t const& v) const { + if (v == NPos) + ds.writeVlqU(0); + else + ds.writeVlqU(v + 1); +} + +void NetElementBool::readData(DataStream& ds, bool& v) const { + ds.read(v); +} + +void NetElementBool::writeData(DataStream& ds, bool const& v) const { + ds.write(v); +} + +void NetElementEvent::trigger() { + set(get() + 1); +} + +uint64_t NetElementEvent::pullOccurrences() { + uint64_t occurrences = get(); + starAssert(occurrences >= m_pulledOccurrences); + uint64_t unchecked = occurrences - m_pulledOccurrences; + m_pulledOccurrences = occurrences; + return unchecked; +} + +bool NetElementEvent::pullOccurred() { + return pullOccurrences() != 0; +} + +void NetElementEvent::ignoreOccurrences() { + m_pulledOccurrences = get(); +} + +void NetElementEvent::setIgnoreOccurrencesOnNetLoad(bool ignoreOccurrencesOnNetLoad) { + m_ignoreOccurrencesOnNetLoad = ignoreOccurrencesOnNetLoad; +} + +void NetElementEvent::netLoad(DataStream& ds) { + NetElementUInt::netLoad(ds); + if (m_ignoreOccurrencesOnNetLoad) + ignoreOccurrences(); +} + +void NetElementEvent::updated() { + NetElementBasicField::updated(); + if (m_pulledOccurrences > get()) + m_pulledOccurrences = get(); +} + +} diff --git a/source/core/StarNetElementBasicFields.hpp b/source/core/StarNetElementBasicFields.hpp new file mode 100644 index 0000000..f0939a7 --- /dev/null +++ b/source/core/StarNetElementBasicFields.hpp @@ -0,0 +1,331 @@ +#ifndef STAR_NET_STEP_STATES_HPP +#define STAR_NET_STEP_STATES_HPP + +#include <type_traits> + +#include "StarNetElement.hpp" +#include "StarString.hpp" +#include "StarByteArray.hpp" + +namespace Star { + +template <typename T> +class NetElementBasicField : public NetElement { +public: + virtual ~NetElementBasicField() = default; + + T const& get() const; + + // Updates the value if the value is different than the existing value, + // requires T have operator== + void set(T const& value); + + // Always updates the value and marks it as updated. + void push(T value); + + // Has this field been updated since the last call to pullUpdated? + bool pullUpdated(); + + // Update the value in place. The mutator will be called as bool + // mutator(T&), return true to signal that the value was updated. + template <typename Mutator> + void update(Mutator&& mutator); + + void initNetVersion(NetElementVersion const* version = nullptr) override; + + // Values are never interpolated, but they will be delayed for the given + // interpolationTime. + void enableNetInterpolation(float extrapolationHint = 0.0f) override; + void disableNetInterpolation() override; + void tickNetInterpolation(float dt) override; + + void netStore(DataStream& ds) const override; + void netLoad(DataStream& ds) override; + + bool writeNetDelta(DataStream& ds, uint64_t fromVersion) const override; + void readNetDelta(DataStream& ds, float interpolationTime = 0.0f) override; + +protected: + virtual void readData(DataStream& ds, T& t) const = 0; + virtual void writeData(DataStream& ds, T const& t) const = 0; + + virtual void updated(); + +private: + NetElementVersion const* m_netVersion = nullptr; + uint64_t m_latestUpdateVersion = 0; + T m_value = T(); + bool m_updated = false; + Maybe<Deque<pair<float, T>>> m_pendingInterpolatedValues; +}; + +template <typename T> +class NetElementIntegral : public NetElementBasicField<T> { +protected: + void readData(DataStream& ds, T& v) const override; + void writeData(DataStream& ds, T const& v) const override; +}; + +typedef NetElementIntegral<int64_t> NetElementInt; +typedef NetElementIntegral<uint64_t> NetElementUInt; + +// Properly encodes NPos no matter the platform width of size_t NetElement +// size_t values are NOT clamped when setting. +class NetElementSize : public NetElementBasicField<size_t> { +protected: + void readData(DataStream& ds, size_t& v) const override; + void writeData(DataStream& ds, size_t const& v) const override; +}; + +class NetElementBool : public NetElementBasicField<bool> { +protected: + void readData(DataStream& ds, bool& v) const override; + void writeData(DataStream& ds, bool const& v) const override; +}; + +template <typename Enum> +class NetElementEnum : public NetElementBasicField<Enum> { +protected: + void readData(DataStream& ds, Enum& v) const override; + void writeData(DataStream& ds, Enum const& v) const override; +}; + +// Wraps a uint64_t to give a simple event stream. Every trigger is an +// increment to a held uint64_t value, and slaves can see how many triggers +// have occurred since the last check. +class NetElementEvent : public NetElementUInt { +public: + void trigger(); + + // Returns the number of times this event has been triggered since the last + // pullOccurrences call. + uint64_t pullOccurrences(); + + // Pulls whether this event occurred at all, ignoring the number + bool pullOccurred(); + + // Ignore all the existing ocurrences + void ignoreOccurrences(); + void setIgnoreOccurrencesOnNetLoad(bool ignoreOccurrencesOnNetLoad); + + void netLoad(DataStream& ds) override; + +protected: + void updated() override; + +private: + using NetElementUInt::get; + using NetElementUInt::set; + using NetElementUInt::push; + using NetElementUInt::update; + + uint64_t m_pulledOccurrences = 0; + bool m_ignoreOccurrencesOnNetLoad = false; +}; + +// Holds an arbitrary serializable value +template <typename T> +class NetElementData : public NetElementBasicField<T> { +public: + NetElementData(); + NetElementData(function<void(DataStream&, T&)> reader, function<void(DataStream&, T const&)> writer); + +protected: + void readData(DataStream& ds, T& v) const override; + void writeData(DataStream& ds, T const& v) const override; + +private: + function<void(DataStream&, T&)> m_reader; + function<void(DataStream&, T const&)> m_writer; +}; + +typedef NetElementData<String> NetElementString; +typedef NetElementData<ByteArray> NetElementBytes; + +template <typename T> +T const& NetElementBasicField<T>::get() const { + return m_value; +} + +template <typename T> +void NetElementBasicField<T>::set(T const& value) { + if (!(m_value == value)) + push(value); +} + +template <typename T> +void NetElementBasicField<T>::push(T value) { + m_value = move(value); + updated(); + m_latestUpdateVersion = m_netVersion ? m_netVersion->current() : 0; + if (m_pendingInterpolatedValues) + m_pendingInterpolatedValues->clear(); +} + +template <typename T> +bool NetElementBasicField<T>::pullUpdated() { + return take(m_updated); +} + +template <typename T> +template <typename Mutator> +void NetElementBasicField<T>::update(Mutator&& mutator) { + if (mutator(m_value)) { + updated(); + m_latestUpdateVersion = m_netVersion ? m_netVersion->current() : 0; + if (m_pendingInterpolatedValues) + m_pendingInterpolatedValues->clear(); + } +} + +template <typename T> +void NetElementBasicField<T>::initNetVersion(NetElementVersion const* version) { + m_netVersion = version; + m_latestUpdateVersion = 0; +} + +template <typename T> +void NetElementBasicField<T>::enableNetInterpolation(float) { + if (!m_pendingInterpolatedValues) + m_pendingInterpolatedValues.emplace(); +} + +template <typename T> +void NetElementBasicField<T>::disableNetInterpolation() { + if (m_pendingInterpolatedValues) { + if (!m_pendingInterpolatedValues->empty()) + m_value = m_pendingInterpolatedValues->takeLast().second; + m_pendingInterpolatedValues.reset(); + } +} + +template <typename T> +void NetElementBasicField<T>::tickNetInterpolation(float dt) { + if (m_pendingInterpolatedValues) { + for (auto& p : *m_pendingInterpolatedValues) + p.first -= dt; + while (!m_pendingInterpolatedValues->empty() && m_pendingInterpolatedValues->first().first <= 0.0f) { + m_value = m_pendingInterpolatedValues->takeFirst().second; + updated(); + } + } +} + +template <typename T> +void NetElementBasicField<T>::netStore(DataStream& ds) const { + if (m_pendingInterpolatedValues && !m_pendingInterpolatedValues->empty()) + writeData(ds, m_pendingInterpolatedValues->last().second); + else + writeData(ds, m_value); +} + +template <typename T> +void NetElementBasicField<T>::netLoad(DataStream& ds) { + readData(ds, m_value); + m_latestUpdateVersion = m_netVersion ? m_netVersion->current() : 0; + updated(); + if (m_pendingInterpolatedValues) + m_pendingInterpolatedValues->clear(); +} + +template <typename T> +bool NetElementBasicField<T>::writeNetDelta(DataStream& ds, uint64_t fromVersion) const { + if (m_latestUpdateVersion < fromVersion) + return false; + + if (m_pendingInterpolatedValues && !m_pendingInterpolatedValues->empty()) + writeData(ds, m_pendingInterpolatedValues->last().second); + else + writeData(ds, m_value); + return true; +} + +template <typename T> +void NetElementBasicField<T>::readNetDelta(DataStream& ds, float interpolationTime) { + T t; + readData(ds, t); + m_latestUpdateVersion = m_netVersion ? m_netVersion->current() : 0; + if (m_pendingInterpolatedValues) { + // Only append an incoming delta to our pending value list if the incoming + // step is forward in time of every other pending value. In any other + // case, this is an error or the step tracking is wildly off, so just clear + // any other incoming values. + if (interpolationTime > 0.0f && (m_pendingInterpolatedValues->empty() || interpolationTime >= m_pendingInterpolatedValues->last().first)) { + m_pendingInterpolatedValues->append({interpolationTime, move(t)}); + } else { + m_value = move(t); + m_pendingInterpolatedValues->clear(); + updated(); + } + } else { + m_value = move(t); + updated(); + } +} + +template <typename T> +void NetElementBasicField<T>::updated() { + m_updated = true; +} + +template <typename T> +void NetElementIntegral<T>::readData(DataStream& ds, T& v) const { + if (sizeof(T) == 1) { + ds.read(v); + } else { + if (std::is_unsigned<T>::value) + v = ds.readVlqU(); + else + v = ds.readVlqI(); + } +} + +template <typename T> +void NetElementIntegral<T>::writeData(DataStream& ds, T const& v) const { + if (sizeof(T) == 1) { + ds.write(v); + } else { + if (std::is_unsigned<T>::value) + ds.writeVlqU(v); + else + ds.writeVlqI(v); + } +} + +template <typename Enum> +void NetElementEnum<Enum>::readData(DataStream& ds, Enum& v) const { + if (sizeof(Enum) == 1) + ds.read(v); + else + v = (Enum)ds.readVlqI(); +} + +template <typename Enum> +void NetElementEnum<Enum>::writeData(DataStream& ds, Enum const& v) const { + if (sizeof(Enum) == 1) + ds.write(v); + else + ds.writeVlqI((int64_t)v); +} + +template <typename T> +NetElementData<T>::NetElementData() + : NetElementData([](DataStream& ds, T & t) { ds >> t; }, [](DataStream& ds, T const& t) { ds << t; }) {} + +template <typename T> +NetElementData<T>::NetElementData(function<void(DataStream&, T&)> reader, function<void(DataStream&, T const&)> writer) + : m_reader(move(reader)), m_writer(move(writer)) {} + +template <typename T> +void NetElementData<T>::readData(DataStream& ds, T& v) const { + m_reader(ds, v); +} + +template <typename T> +void NetElementData<T>::writeData(DataStream& ds, T const& v) const { + m_writer(ds, v); +} + +} + +#endif diff --git a/source/core/StarNetElementContainers.hpp b/source/core/StarNetElementContainers.hpp new file mode 100644 index 0000000..5c2fd71 --- /dev/null +++ b/source/core/StarNetElementContainers.hpp @@ -0,0 +1,452 @@ +#ifndef STAR_NET_ELEMENT_CONTAINERS_HPP +#define STAR_NET_ELEMENT_CONTAINERS_HPP + +#include "StarMap.hpp" +#include "StarDataStreamExtra.hpp" +#include "StarNetElement.hpp" +#include "StarStrongTypedef.hpp" + +namespace Star { + +// NetElement map container that is more efficient than the naive serialization +// of an entire Map, because it delta encodes changes to save networking +// traffic. +template <typename BaseMap> +class NetElementMapWrapper : public NetElement, private BaseMap { +public: + typedef typename BaseMap::iterator iterator; + typedef typename BaseMap::const_iterator const_iterator; + + typedef typename BaseMap::key_type key_type; + typedef typename BaseMap::mapped_type mapped_type; + typedef typename BaseMap::value_type value_type; + + void initNetVersion(NetElementVersion const* version = nullptr) override; + + void enableNetInterpolation(float extrapolationHint = 0.0f) override; + void disableNetInterpolation() override; + void tickNetInterpolation(float dt) override; + + void netStore(DataStream& ds) const override; + void netLoad(DataStream& ds) override; + + bool writeNetDelta(DataStream& ds, uint64_t fromVersion) const override; + void readNetDelta(DataStream& ds, float interpolationTime = 0.0f) override; + + mapped_type const& get(key_type const& key) const; + mapped_type const* ptr(key_type const& key) const; + + const_iterator begin() const; + const_iterator end() const; + + using BaseMap::keys; + using BaseMap::values; + using BaseMap::pairs; + using BaseMap::contains; + using BaseMap::size; + using BaseMap::empty; + using BaseMap::maybe; + using BaseMap::value; + + pair<const_iterator, bool> insert(value_type v); + pair<const_iterator, bool> insert(key_type k, mapped_type v); + + void add(key_type k, mapped_type v); + // Calling set with a matching key and value does not cause a delta to be + // produced + void set(key_type k, mapped_type v); + // set requires that mapped_type implement operator==, push always generates + // a delta and does not require mapped_type operator== + void push(key_type k, mapped_type v); + + bool remove(key_type const& k); + + const_iterator erase(const_iterator i); + + mapped_type take(key_type const& k); + Maybe<mapped_type> maybeTake(key_type const& k); + + void clear(); + + BaseMap const& baseMap() const; + void reset(BaseMap values); + bool pullUpdated(); + + // Sets this map to contain the same keys / values as the given map. All + // values in this map not found in the given map are removed. (Same as + // reset, but with arbitrary map type). + template <typename MapType> + void setContents(MapType const& values); + +private: + // If a delta is written from further back than this many steps, the delta + // will fall back to a full serialization of the entire state. + static int64_t const MaxChangeDataVersions = 100; + + struct SetChange { + key_type key; + mapped_type value; + }; + struct RemoveChange { + key_type key; + }; + struct ClearChange {}; + + typedef Variant<SetChange, RemoveChange, ClearChange> ElementChange; + + static void writeChange(DataStream& ds, ElementChange const& change); + static ElementChange readChange(DataStream& ds); + + void addChangeData(ElementChange change); + + void addPendingChangeData(ElementChange change, float interpolationTime); + void applyChange(ElementChange change); + + Deque<pair<uint64_t, ElementChange>> m_changeData; + Deque<pair<float, ElementChange>> m_pendingChangeData; + NetElementVersion const* m_netVersion = nullptr; + uint64_t m_changeDataLastVersion = 0; + bool m_updated = false; + bool m_interpolationEnabled = false; +}; + +template <typename Key, typename Value> +using NetElementMap = NetElementMapWrapper<Map<Key, Value>>; + +template <typename Key, typename Value> +using NetElementHashMap = NetElementMapWrapper<HashMap<Key, Value>>; + +template <typename BaseMap> +void NetElementMapWrapper<BaseMap>::initNetVersion(NetElementVersion const* version) { + m_netVersion = version; + + m_changeData.clear(); + m_changeDataLastVersion = 0; + + for (auto& change : Star::take(m_pendingChangeData)) + applyChange(move(change.second)); + + addChangeData(ClearChange()); + for (auto const& p : *this) + addChangeData(SetChange{p.first, p.second}); +} + +template <typename BaseMap> +void NetElementMapWrapper<BaseMap>::enableNetInterpolation(float) { + m_interpolationEnabled = true; +} + +template <typename BaseMap> +void NetElementMapWrapper<BaseMap>::disableNetInterpolation() { + m_interpolationEnabled = false; + for (auto& change : Star::take(m_pendingChangeData)) + applyChange(move(change.second)); +} + +template <typename BaseMap> +void NetElementMapWrapper<BaseMap>::tickNetInterpolation(float dt) { + for (auto& p : m_pendingChangeData) + p.first -= dt; + + while (!m_pendingChangeData.empty() && m_pendingChangeData.first().first <= 0.0f) + applyChange(m_pendingChangeData.takeFirst().second); +} + +template <typename BaseMap> +void NetElementMapWrapper<BaseMap>::netStore(DataStream& ds) const { + ds.writeVlqU(BaseMap::size() + m_pendingChangeData.size()); + for (auto const& pair : *this) + writeChange(ds, SetChange{pair.first, pair.second}); + + for (auto const& p : m_pendingChangeData) + writeChange(ds, p.second); +} + +template <typename BaseMap> +void NetElementMapWrapper<BaseMap>::netLoad(DataStream& ds) { + m_changeData.clear(); + m_changeDataLastVersion = m_netVersion ? m_netVersion->current() : 0; + m_pendingChangeData.clear(); + BaseMap::clear(); + + addChangeData(ClearChange()); + + uint64_t count = ds.readVlqU(); + for (uint64_t i = 0; i < count; ++i) { + auto change = readChange(ds); + addChangeData(change); + applyChange(move(change)); + } + + m_updated = true; +} + +template <typename BaseMap> +bool NetElementMapWrapper<BaseMap>::writeNetDelta(DataStream& ds, uint64_t fromVersion) const { + bool deltaWritten = false; + + if (fromVersion < m_changeDataLastVersion) { + deltaWritten = true; + ds.writeVlqU(1); + netStore(ds); + + } else { + for (auto const& p : m_changeData) { + if (p.first >= fromVersion) { + deltaWritten = true; + ds.writeVlqU(2); + writeChange(ds, p.second); + } + } + } + + if (deltaWritten) + ds.writeVlqU(0); + + return deltaWritten; +} + +template <typename BaseMap> +void NetElementMapWrapper<BaseMap>::readNetDelta(DataStream& ds, float interpolationTime) { + while (true) { + uint64_t code = ds.readVlqU(); + if (code == 0) { + break; + } else if (code == 1) { + netLoad(ds); + } else if (code == 2) { + auto change = readChange(ds); + addChangeData(change); + + if (m_interpolationEnabled && interpolationTime > 0.0f) + addPendingChangeData(move(change), interpolationTime); + else + applyChange(move(change)); + } else { + throw IOException("Improper delta code received in NetElementMapWrapper::readNetDelta"); + } + } +} + +template <typename BaseMap> +auto NetElementMapWrapper<BaseMap>::get(key_type const& key) const -> mapped_type const & { + return BaseMap::get(key); +} + +template <typename BaseMap> +auto NetElementMapWrapper<BaseMap>::ptr(key_type const& key) const -> mapped_type const * { + return BaseMap::ptr(key); +} + +template <typename BaseMap> +auto NetElementMapWrapper<BaseMap>::begin() const -> const_iterator { + return BaseMap::begin(); +} + +template <typename BaseMap> +auto NetElementMapWrapper<BaseMap>::end() const -> const_iterator { + return BaseMap::end(); +} + +template <typename BaseMap> +auto NetElementMapWrapper<BaseMap>::insert(value_type v) -> pair<const_iterator, bool> { + auto res = BaseMap::insert(v); + if (res.second) { + addChangeData(SetChange{move(v.first), move(v.second)}); + m_updated = true; + } + return res; +} + +template <typename BaseMap> +auto NetElementMapWrapper<BaseMap>::insert(key_type k, mapped_type v) -> pair<const_iterator, bool> { + return insert(value_type(move(k), move(v))); +} + +template <typename BaseMap> +void NetElementMapWrapper<BaseMap>::add(key_type k, mapped_type v) { + if (!insert(value_type(move(k), move(v))).second) + throw MapException::format("Entry with key '%s' already present.", outputAny(k)); +} + +template <typename BaseMap> +void NetElementMapWrapper<BaseMap>::set(key_type k, mapped_type v) { + auto i = BaseMap::find(k); + if (i != BaseMap::end()) { + if (!(i->second == v)) { + addChangeData(SetChange{move(k), v}); + i->second = move(v); + m_updated = true; + } + } else { + addChangeData(SetChange{k, v}); + BaseMap::insert(value_type(move(k), move(v))); + m_updated = true; + } +} + +template <typename BaseMap> +void NetElementMapWrapper<BaseMap>::push(key_type k, mapped_type v) { + auto i = BaseMap::find(k); + if (i != BaseMap::end()) { + addChangeData(SetChange(move(k), v)); + i->second = move(v); + } else { + addChangeData(SetChange(k, v)); + BaseMap::insert(value_type(move(k), move(v))); + } + m_updated = true; +} + +template <typename BaseMap> +bool NetElementMapWrapper<BaseMap>::remove(key_type const& k) { + auto i = BaseMap::find(k); + if (i != BaseMap::end()) { + BaseMap::erase(i); + addChangeData(RemoveChange{k}); + m_updated = true; + return true; + } + return false; +} + +template <typename BaseMap> +auto NetElementMapWrapper<BaseMap>::erase(const_iterator i) -> const_iterator { + addChangeData(RemoveChange(i->first)); + m_updated = true; + return BaseMap::erase(i); +} + +template <typename BaseMap> +auto NetElementMapWrapper<BaseMap>::take(key_type const& k) -> mapped_type { + auto i = BaseMap::find(k); + if (i == BaseMap::end()) + throw MapException::format("Key '%s' not found in Map::take()", outputAny(k)); + auto m = move(i->second); + erase(i); + return m; +} + +template <typename BaseMap> +auto NetElementMapWrapper<BaseMap>::maybeTake(key_type const& k) -> Maybe<mapped_type> { + auto i = BaseMap::find(k); + if (i == BaseMap::end()) + return {}; + auto m = move(i->second); + erase(i); + return Maybe<mapped_type>(move(m)); +} + +template <typename BaseMap> +void NetElementMapWrapper<BaseMap>::clear() { + if (!empty()) { + addChangeData(ClearChange()); + m_updated = true; + BaseMap::clear(); + } +} + +template <typename BaseMap> +BaseMap const& NetElementMapWrapper<BaseMap>::baseMap() const { + return *this; +} + +template <typename BaseMap> +void NetElementMapWrapper<BaseMap>::reset(BaseMap values) { + for (auto const& p : *this) { + if (!values.contains(p.first)) { + addChangeData(RemoveChange{p.first}); + m_updated = true; + } + } + + for (auto const& p : values) { + auto v = ptr(p.first); + if (!v || !(*v == p.second)) { + addChangeData(SetChange{p.first, p.second}); + m_updated = true; + } + } + + BaseMap::operator=(move(values)); +} + +template <typename BaseMap> +bool NetElementMapWrapper<BaseMap>::pullUpdated() { + return Star::take(m_updated); +} + +template <typename BaseMap> +template <typename MapType> +void NetElementMapWrapper<BaseMap>::setContents(MapType const& values) { + reset(BaseMap::from(values)); +} + +template <typename BaseMap> +void NetElementMapWrapper<BaseMap>::writeChange(DataStream& ds, ElementChange const& change) { + if (auto sc = change.template ptr<SetChange>()) { + ds.write<uint8_t>(0); + ds.write(sc->key); + ds.write(sc->value); + } else if (auto rc = change.template ptr<RemoveChange>()) { + ds.write<uint8_t>(1); + ds.write(rc->key); + } else { + ds.write<uint8_t>(2); + } +} + +template <typename BaseMap> +auto NetElementMapWrapper<BaseMap>::readChange(DataStream& ds) -> ElementChange { + uint8_t t = ds.read<uint8_t>(); + if (t == 0) { + SetChange sc; + ds.read(sc.key); + ds.read(sc.value); + return sc; + } else if (t == 1) { + RemoveChange rc; + ds.read(rc.key); + return rc; + } else if (t == 2) { + return ClearChange(); + } else { + throw IOException("Improper type code received in NetElementMapWrapper::readChange"); + } +} + +template <typename BaseMap> +void NetElementMapWrapper<BaseMap>::addChangeData(ElementChange change) { + uint64_t currentVersion = m_netVersion ? m_netVersion->current() : 0; + starAssert(m_changeData.empty() || m_changeData.last().first <= currentVersion); + + m_changeData.append({currentVersion, move(change)}); + + m_changeDataLastVersion = max<int64_t>((int64_t)currentVersion - MaxChangeDataVersions, 0); + while (!m_changeData.empty() && m_changeData.first().first < m_changeDataLastVersion) + m_changeData.removeFirst(); +} + +template <typename BaseMap> +void NetElementMapWrapper<BaseMap>::addPendingChangeData(ElementChange change, float interpolationTime) { + if (!m_pendingChangeData.empty() && interpolationTime < m_pendingChangeData.last().first) { + for (auto& change : Star::take(m_pendingChangeData)) + applyChange(move(change.second)); + } + m_pendingChangeData.append({interpolationTime, move(change)}); +} + +template <typename BaseMap> +void NetElementMapWrapper<BaseMap>::applyChange(ElementChange change) { + if (auto set = change.template ptr<SetChange>()) + BaseMap::set(move(set->key), move(set->value)); + else if (auto remove = change.template ptr<RemoveChange>()) + BaseMap::remove(move(remove->key)); + else + BaseMap::clear(); + m_updated = true; +} + +} + +#endif diff --git a/source/core/StarNetElementDynamicGroup.hpp b/source/core/StarNetElementDynamicGroup.hpp new file mode 100644 index 0000000..5d34c8d --- /dev/null +++ b/source/core/StarNetElementDynamicGroup.hpp @@ -0,0 +1,317 @@ +#ifndef STAR_NET_ELEMENT_DYNAMIC_GROUP_HPP +#define STAR_NET_ELEMENT_DYNAMIC_GROUP_HPP + +#include "StarNetElement.hpp" +#include "StarIdMap.hpp" +#include "StarStrongTypedef.hpp" +#include "StarDataStreamExtra.hpp" + +namespace Star { + +// A dynamic group of NetElements that manages creation and destruction of +// individual elements, that is itself a NetElement. Element changes are not +// delayed by the interpolation delay, they will always happen immediately, but +// this does not inhibit the Elements themselves from handling their own delta +// update delays normally. +template <typename Element> +class NetElementDynamicGroup : public NetElement { +public: + typedef shared_ptr<Element> ElementPtr; + typedef uint32_t ElementId; + static ElementId const NullElementId = 0; + + NetElementDynamicGroup() = default; + + NetElementDynamicGroup(NetElementDynamicGroup const&) = delete; + NetElementDynamicGroup& operator=(NetElementDynamicGroup const&) = delete; + + // Must not call addNetElement / removeNetElement when being used as a slave, + // id errors will result. + ElementId addNetElement(ElementPtr element); + void removeNetElement(ElementId id); + + // Remove all elements + void clearNetElements(); + + List<ElementId> netElementIds() const; + ElementPtr getNetElement(ElementId id) const; + + List<ElementPtr> netElements() const; + + void initNetVersion(NetElementVersion const* version = nullptr) override; + + // Values are never interpolated, but they will be delayed for the given + // interpolationTime. + void enableNetInterpolation(float extrapolationHint = 0.0f) override; + void disableNetInterpolation() override; + void tickNetInterpolation(float dt) override; + + void netStore(DataStream& ds) const override; + void netLoad(DataStream& ds) override; + + bool writeNetDelta(DataStream& ds, uint64_t fromVersion) const override; + void readNetDelta(DataStream& ds, float interpolationTime = 0.0f) override; + void blankNetDelta(float interpolationTime = 0.0f) override; + +private: + // If a delta is written from further back than this many versions, the delta + // will fall back to a full serialization of the entire state. + static int64_t const MaxChangeDataVersions = 100; + + typedef ElementId ElementRemovalType; + typedef pair<ElementId, ByteArray> ElementAdditionType; + + strong_typedef(Empty, ElementReset); + strong_typedef_builtin(ElementRemovalType, ElementRemoval); + strong_typedef(ElementAdditionType, ElementAddition); + + typedef Variant<ElementReset, ElementRemoval, ElementAddition> ElementChange; + + typedef IdMap<ElementId, ElementPtr> ElementMap; + + void addChangeData(ElementChange change); + + void readyElement(ElementPtr const& element); + + NetElementVersion const* m_netVersion = nullptr; + bool m_interpolationEnabled = false; + float m_extrapolationHint = 0.0f; + + ElementMap m_idMap = ElementMap(1, highest<ElementId>()); + + Deque<pair<uint64_t, ElementChange>> m_changeData; + uint64_t m_changeDataLastVersion = 0; + + mutable DataStreamBuffer m_buffer; + mutable HashSet<ElementId> m_receivedDeltaIds; +}; + +template <typename Element> +auto NetElementDynamicGroup<Element>::addNetElement(ElementPtr element) -> ElementId { + readyElement(element); + DataStreamBuffer storeBuffer; + element->netStore(storeBuffer); + auto id = m_idMap.add(move(element)); + + addChangeData(ElementAddition(id, storeBuffer.takeData())); + + return id; +} + +template <typename Element> +void NetElementDynamicGroup<Element>::removeNetElement(ElementId id) { + m_idMap.remove(id); + addChangeData(ElementRemoval{id}); +} + +template <typename Element> +void NetElementDynamicGroup<Element>::clearNetElements() { + for (auto const& id : netElementIds()) + removeNetElement(id); +} + +template <typename Element> +auto NetElementDynamicGroup<Element>::netElementIds() const -> List<ElementId> { + return m_idMap.keys(); +} + +template <typename Element> +auto NetElementDynamicGroup<Element>::getNetElement(ElementId id) const -> ElementPtr { + return m_idMap.get(id); +} + +template <typename Element> +auto NetElementDynamicGroup<Element>::netElements() const -> List<ElementPtr> { + return m_idMap.values(); +} + +template <typename Element> +void NetElementDynamicGroup<Element>::initNetVersion(NetElementVersion const* version) { + m_netVersion = version; + m_changeData.clear(); + m_changeDataLastVersion = 0; + + addChangeData(ElementReset()); + for (auto& pair : m_idMap) { + pair.second->initNetVersion(m_netVersion); + DataStreamBuffer storeBuffer; + pair.second->netStore(storeBuffer); + addChangeData(ElementAddition(pair.first, storeBuffer.takeData())); + } +} + +template <typename Element> +void NetElementDynamicGroup<Element>::enableNetInterpolation(float extrapolationHint) { + m_interpolationEnabled = true; + m_extrapolationHint = extrapolationHint; + for (auto& p : m_idMap) + p.second->enableNetInterpolation(extrapolationHint); +} + +template <typename Element> +void NetElementDynamicGroup<Element>::disableNetInterpolation() { + m_interpolationEnabled = false; + m_extrapolationHint = 0.0f; + for (auto& p : m_idMap) + p.second->disableNetInterpolation(); +} + +template <typename Element> +void NetElementDynamicGroup<Element>::tickNetInterpolation(float dt) { + for (auto& p : m_idMap) + p.second->tickNetInterpolation(dt); +} + +template <typename Element> +void NetElementDynamicGroup<Element>::netStore(DataStream& ds) const { + ds.writeVlqU(m_idMap.size()); + + for (auto& pair : m_idMap) { + ds.writeVlqU(pair.first); + pair.second->netStore(m_buffer); + ds.write(m_buffer.data()); + m_buffer.clear(); + } +} + +template <typename Element> +void NetElementDynamicGroup<Element>::netLoad(DataStream& ds) { + m_changeData.clear(); + m_changeDataLastVersion = m_netVersion ? m_netVersion->current() : 0; + m_idMap.clear(); + + addChangeData(ElementReset()); + + uint64_t count = ds.readVlqU(); + + for (uint64_t i = 0; i < count; ++i) { + ElementId id = ds.readVlqU(); + DataStreamBuffer storeBuffer(ds.read<ByteArray>()); + + ElementPtr element = make_shared<Element>(); + element->netLoad(storeBuffer); + readyElement(element); + + m_idMap.add(id, move(element)); + addChangeData(ElementAddition(id, storeBuffer.takeData())); + } +} + +template <typename Element> +bool NetElementDynamicGroup<Element>::writeNetDelta(DataStream& ds, uint64_t fromVersion) const { + if (fromVersion < m_changeDataLastVersion) { + ds.write<bool>(true); + netStore(ds); + return true; + + } else { + bool deltaWritten = false; + auto willWrite = [&]() { + if (!deltaWritten) { + deltaWritten = true; + ds.write<bool>(false); + } + }; + + for (auto const& p : m_changeData) { + if (p.first >= fromVersion) { + willWrite(); + ds.writeVlqU(1); + ds.write(p.second); + } + } + + for (auto& p : m_idMap) { + if (p.second->writeNetDelta(m_buffer, fromVersion)) { + willWrite(); + ds.writeVlqU(p.first + 1); + ds.writeBytes(m_buffer.data()); + m_buffer.clear(); + } + } + + if (deltaWritten) + ds.writeVlqU(0); + + return deltaWritten; + } +} + +template <typename Element> +void NetElementDynamicGroup<Element>::readNetDelta(DataStream& ds, float interpolationTime) { + bool isFull = ds.read<bool>(); + if (isFull) { + netLoad(ds); + } else { + while (true) { + uint64_t code = ds.readVlqU(); + if (code == 0) { + break; + } + if (code == 1) { + auto changeUpdate = ds.read<ElementChange>(); + addChangeData(changeUpdate); + + if (changeUpdate.template is<ElementReset>()) { + m_idMap.clear(); + } else if (auto addition = changeUpdate.template ptr<ElementAddition>()) { + ElementPtr element = make_shared<Element>(); + DataStreamBuffer storeBuffer(move(get<1>(*addition))); + element->netLoad(storeBuffer); + readyElement(element); + m_idMap.add(get<0>(*addition), move(element)); + } else if (auto removal = changeUpdate.template ptr<ElementRemoval>()) { + m_idMap.remove(*removal); + } + } else { + ElementId elementId = code - 1; + auto const& element = m_idMap.get(elementId); + element->readNetDelta(ds, interpolationTime); + if (m_interpolationEnabled) + m_receivedDeltaIds.add(elementId); + } + } + + if (m_interpolationEnabled) { + for (auto& p : m_idMap) { + if (!m_receivedDeltaIds.contains(p.first)) + p.second->blankNetDelta(interpolationTime); + } + + m_receivedDeltaIds.clear(); + } + } +} + +template <typename Element> +void NetElementDynamicGroup<Element>::blankNetDelta(float interpolationTime) { + if (m_interpolationEnabled) { + for (auto& p : m_idMap) + p.second->blankNetDelta(interpolationTime); + } +} + +template <typename Element> +void NetElementDynamicGroup<Element>::addChangeData(ElementChange change) { + uint64_t currentVersion = m_netVersion ? m_netVersion->current() : 0; + starAssert(m_changeData.empty() || m_changeData.last().first <= currentVersion); + + m_changeData.append({currentVersion, move(change)}); + + m_changeDataLastVersion = max<int64_t>((int64_t)currentVersion - MaxChangeDataVersions, 0); + while (!m_changeData.empty() && m_changeData.first().first < m_changeDataLastVersion) + m_changeData.removeFirst(); +} + +template <typename Element> +void NetElementDynamicGroup<Element>::readyElement(ElementPtr const& element) { + element->initNetVersion(m_netVersion); + if (m_interpolationEnabled) + element->enableNetInterpolation(m_extrapolationHint); + else + element->disableNetInterpolation(); +} + +} + +#endif diff --git a/source/core/StarNetElementFloatFields.hpp b/source/core/StarNetElementFloatFields.hpp new file mode 100644 index 0000000..835b3ec --- /dev/null +++ b/source/core/StarNetElementFloatFields.hpp @@ -0,0 +1,246 @@ +#ifndef STAR_NET_ELEMENT_FLOAT_FIELDS_HPP +#define STAR_NET_ELEMENT_FLOAT_FIELDS_HPP + +#include <type_traits> + +#include "StarNetElement.hpp" +#include "StarInterpolation.hpp" + +namespace Star { + +STAR_EXCEPTION(StepStreamException, StarException); + +template <typename T> +class NetElementFloating : public NetElement { +public: + T get() const; + void set(T value); + + // If a fixed point base is given, then instead of transmitting the value as + // a float, it is transmitted as a VLQ of the value divided by the fixed + // point base. Any NetElementFloating that is transmitted to must also have + // the same fixed point base set. + void setFixedPointBase(Maybe<T> fixedPointBase = {}); + + // If interpolation is enabled on the NetStepStates parent, and an + // interpolator is set, then on steps in between data points this will be + // used to interpolate this value. It is not necessary that senders and + // receivers both have matching interpolation functions, or any interpolation + // functions at all. + void setInterpolator(function<T(T, T, T)> interpolator); + + void initNetVersion(NetElementVersion const* version = nullptr) override; + + // Values are never interpolated, but they will be delayed for the given + // interpolationTime. + void enableNetInterpolation(float extrapolationHint = 0.0f) override; + void disableNetInterpolation() override; + void tickNetInterpolation(float dt) override; + + void netStore(DataStream& ds) const override; + void netLoad(DataStream& ds) override; + + bool writeNetDelta(DataStream& ds, uint64_t fromVersion) const override; + void readNetDelta(DataStream& ds, float interpolationTime = 0.0f) override; + void blankNetDelta(float interpolationTime = 0.0f) override; + +private: + void writeValue(DataStream& ds, T t) const; + T readValue(DataStream& ds) const; + + T interpolate() const; + + Maybe<T> m_fixedPointBase; + NetElementVersion const* m_netVersion = nullptr; + uint64_t m_latestUpdateVersion = 0; + T m_value = T(); + + function<T(T, T, T)> m_interpolator; + float m_extrapolation = 0.0f; + Maybe<Deque<pair<float, T>>> m_interpolationDataPoints; +}; + +typedef NetElementFloating<float> NetElementFloat; +typedef NetElementFloating<double> NetElementDouble; + +template <typename T> +T NetElementFloating<T>::get() const { + return m_value; +} + +template <typename T> +void NetElementFloating<T>::set(T value) { + if (m_value != value) { + // Only mark the step as updated here if it actually would change the + // transmitted value. + if (!m_fixedPointBase || round(m_value / *m_fixedPointBase) != round(value / *m_fixedPointBase)) + m_latestUpdateVersion = m_netVersion ? m_netVersion->current() : 0; + + m_value = value; + + if (m_interpolationDataPoints) { + m_interpolationDataPoints->clear(); + m_interpolationDataPoints->append({0.0f, m_value}); + } + } +} + +template <typename T> +void NetElementFloating<T>::setFixedPointBase(Maybe<T> fixedPointBase) { + m_fixedPointBase = fixedPointBase; +} + +template <typename T> +void NetElementFloating<T>::setInterpolator(function<T(T, T, T)> interpolator) { + m_interpolator = move(interpolator); +} + +template <typename T> +void NetElementFloating<T>::initNetVersion(NetElementVersion const* version) { + m_netVersion = version; + m_latestUpdateVersion = 0; +} + +template <typename T> +void NetElementFloating<T>::enableNetInterpolation(float extrapolationHint) { + m_extrapolation = extrapolationHint; + if (!m_interpolationDataPoints) { + m_interpolationDataPoints.emplace(); + m_interpolationDataPoints->append({0.0f, m_value}); + } +} + +template <typename T> +void NetElementFloating<T>::disableNetInterpolation() { + if (m_interpolationDataPoints) { + m_value = m_interpolationDataPoints->last().second; + m_interpolationDataPoints.reset(); + } +} + +template <typename T> +void NetElementFloating<T>::tickNetInterpolation(float dt) { + if (m_interpolationDataPoints) { + for (auto& p : *m_interpolationDataPoints) + p.first -= dt; + + while (m_interpolationDataPoints->size() > 2 && (*m_interpolationDataPoints)[1].first <= 0.0f) + m_interpolationDataPoints->removeFirst(); + + m_value = interpolate(); + } +} + +template <typename T> +void NetElementFloating<T>::netStore(DataStream& ds) const { + if (m_interpolationDataPoints) + writeValue(ds, m_interpolationDataPoints->last().second); + else + writeValue(ds, m_value); +} + +template <typename T> +void NetElementFloating<T>::netLoad(DataStream& ds) { + m_value = readValue(ds); + m_latestUpdateVersion = m_netVersion ? m_netVersion->current() : 0; + if (m_interpolationDataPoints) { + m_interpolationDataPoints->clear(); + m_interpolationDataPoints->append({0.0f, m_value}); + } +} + +template <typename T> +bool NetElementFloating<T>::writeNetDelta(DataStream& ds, uint64_t fromVersion) const { + if (m_latestUpdateVersion < fromVersion) + return false; + + if (m_interpolationDataPoints) + writeValue(ds, m_interpolationDataPoints->last().second); + else + writeValue(ds, m_value); + + return true; +} + +template <typename T> +void NetElementFloating<T>::readNetDelta(DataStream& ds, float interpolationTime) { + T t = readValue(ds); + + m_latestUpdateVersion = m_netVersion ? m_netVersion->current() : 0; + if (m_interpolationDataPoints) { + if (interpolationTime < m_interpolationDataPoints->last().first) + m_interpolationDataPoints->clear(); + m_interpolationDataPoints->append({interpolationTime, t}); + m_value = interpolate(); + } else { + m_value = t; + } +} + +template <typename T> +void NetElementFloating<T>::blankNetDelta(float interpolationTime) { + if (m_interpolationDataPoints) { + auto lastPoint = m_interpolationDataPoints->last(); + float lastTime = lastPoint.first; + lastPoint.first = interpolationTime; + if (interpolationTime < lastTime) + *m_interpolationDataPoints = {lastPoint}; + else + m_interpolationDataPoints->append(lastPoint); + + m_value = interpolate(); + } +} + +template <typename T> +void NetElementFloating<T>::writeValue(DataStream& ds, T t) const { + if (m_fixedPointBase) + ds.writeVlqI(round(t / *m_fixedPointBase)); + else + ds.write(t); +} + +template <typename T> +T NetElementFloating<T>::readValue(DataStream& ds) const { + T t; + if (m_fixedPointBase) + t = ds.readVlqI() * *m_fixedPointBase; + else + ds.read(t); + return t; +} + +template <typename T> +T NetElementFloating<T>::interpolate() const { + auto& dataPoints = *m_interpolationDataPoints; + + float ipos = inverseLinearInterpolateUpper(dataPoints.begin(), dataPoints.end(), 0.0f, + [](float lhs, auto const& rhs) { + return lhs < rhs.first; + }, [](auto const& dataPoint) { + return dataPoint.first; + }); + auto bound = getBound2(ipos, dataPoints.size(), BoundMode::Extrapolate); + + if (m_interpolator) { + auto const& minPoint = dataPoints[bound.i0]; + auto const& maxPoint = dataPoints[bound.i1]; + + // If step separation is less than 1.0, don't normalize extrapolation to + // the very small step difference, because this can result in large jumps + // during jitter. + float stepDist = max(maxPoint.first - minPoint.first, 1.0f); + float offset = clamp<float>(bound.offset, 0.0f, 1.0f + m_extrapolation / stepDist); + return m_interpolator(offset, minPoint.second, maxPoint.second); + + } else { + if (bound.offset < 1.0f) + return dataPoints[bound.i0].second; + else + return dataPoints[bound.i1].second; + } +} + +} + +#endif diff --git a/source/core/StarNetElementGroup.cpp b/source/core/StarNetElementGroup.cpp new file mode 100644 index 0000000..d8ff6c4 --- /dev/null +++ b/source/core/StarNetElementGroup.cpp @@ -0,0 +1,108 @@ +#include "StarNetElementGroup.hpp" + +namespace Star { + +void NetElementGroup::addNetElement(NetElement* element, bool propagateInterpolation) { + starAssert(!m_elements.any([element](auto p) { return p.first == element; })); + + element->initNetVersion(m_version); + if (m_interpolationEnabled && propagateInterpolation) + element->enableNetInterpolation(m_extrapolationHint); + m_elements.append(pair<NetElement*, bool>(element, propagateInterpolation)); +} + +void NetElementGroup::clearNetElements() { + m_elements.clear(); +} + +void NetElementGroup::initNetVersion(NetElementVersion const* version) { + m_version = version; + for (auto p : m_elements) + p.first->initNetVersion(m_version); +} + +void NetElementGroup::netStore(DataStream& ds) const { + for (auto p : m_elements) + p.first->netStore(ds); +} + +void NetElementGroup::netLoad(DataStream& ds) { + for (auto p : m_elements) + p.first->netLoad(ds); +} + +void NetElementGroup::enableNetInterpolation(float extrapolationHint) { + m_interpolationEnabled = true; + m_extrapolationHint = extrapolationHint; + for (auto p : m_elements) { + if (p.second) + p.first->enableNetInterpolation(extrapolationHint); + } +} + +void NetElementGroup::disableNetInterpolation() { + m_interpolationEnabled = false; + m_extrapolationHint = 0; + for (auto p : m_elements) { + if (p.second) + p.first->disableNetInterpolation(); + } +} + +void NetElementGroup::tickNetInterpolation(float dt) { + if (m_interpolationEnabled) { + for (auto p : m_elements) + p.first->tickNetInterpolation(dt); + } +} + +bool NetElementGroup::writeNetDelta(DataStream& ds, uint64_t fromStep) const { + if (m_elements.size() == 0) { + return false; + } else if (m_elements.size() == 1) { + return m_elements[0].first->writeNetDelta(ds, fromStep); + } else { + bool deltaWritten = false; + for (uint64_t i = 0; i < m_elements.size(); ++i) { + if (m_elements[i].first->writeNetDelta(m_buffer, fromStep)) { + deltaWritten = true; + ds.writeVlqU(i + 1); + ds.writeBytes(m_buffer.data()); + m_buffer.clear(); + } + } + if (deltaWritten) + ds.writeVlqU(0); + return deltaWritten; + } +} + +void NetElementGroup::readNetDelta(DataStream& ds, float interpolationTime) { + if (m_elements.size() == 0) { + throw IOException("readNetDelta called on empty NetElementGroup"); + } else if (m_elements.size() == 1) { + m_elements[0].first->readNetDelta(ds, interpolationTime); + } else { + uint64_t readIndex = ds.readVlqU(); + for (uint64_t i = 0; i < m_elements.size(); ++i) { + if (readIndex == 0 || readIndex - 1 > i) { + if (m_interpolationEnabled) + m_elements[i].first->blankNetDelta(interpolationTime); + } else if (readIndex - 1 == i) { + m_elements[i].first->readNetDelta(ds, interpolationTime); + readIndex = ds.readVlqU(); + } else { + throw IOException("group indexes out of order in NetElementGroup::readNetDelta"); + } + } + } +} + +void NetElementGroup::blankNetDelta(float interpolationTime) { + if (m_interpolationEnabled) { + for (auto p : m_elements) + p.first->blankNetDelta(interpolationTime); + } +} + +} diff --git a/source/core/StarNetElementGroup.hpp b/source/core/StarNetElementGroup.hpp new file mode 100644 index 0000000..d583818 --- /dev/null +++ b/source/core/StarNetElementGroup.hpp @@ -0,0 +1,66 @@ +#ifndef STAR_NET_ELEMENT_GROUP_HPP +#define STAR_NET_ELEMENT_GROUP_HPP + +#include "StarSet.hpp" +#include "StarNetElement.hpp" +#include "StarDataStreamDevices.hpp" + +namespace Star { + +// A static group of NetElements that itself is a NetElement and serializes +// changes based on the order in which elements are added. All participants +// must externally add elements of the correct type in the correct order. +class NetElementGroup : public NetElement { +public: + NetElementGroup() = default; + + NetElementGroup(NetElementGroup const&) = delete; + NetElementGroup& operator=(NetElementGroup const&) = delete; + + // Add an element to the group. + void addNetElement(NetElement* element, bool propagateInterpolation = true); + + // Removes all previously added elements + void clearNetElements(); + + void initNetVersion(NetElementVersion const* version = nullptr) override; + + void netStore(DataStream& ds) const override; + void netLoad(DataStream& ds) override; + + void enableNetInterpolation(float extrapolationHint = 0.0f) override; + void disableNetInterpolation() override; + void tickNetInterpolation(float dt) override; + + bool writeNetDelta(DataStream& ds, uint64_t fromVersion) const override; + void readNetDelta(DataStream& ds, float interpolationTime = 0.0f) override; + void blankNetDelta(float interpolationTime) override; + + NetElementVersion const* netVersion() const; + bool netInterpolationEnabled() const; + float netExtrapolationHint() const; + +private: + List<pair<NetElement*, bool>> m_elements; + NetElementVersion const* m_version = nullptr; + bool m_interpolationEnabled = false; + float m_extrapolationHint = 0.0f; + + mutable DataStreamBuffer m_buffer; +}; + +inline NetElementVersion const* NetElementGroup::netVersion() const { + return m_version; +} + +inline bool NetElementGroup::netInterpolationEnabled() const { + return m_interpolationEnabled; +} + +inline float NetElementGroup::netExtrapolationHint() const { + return m_extrapolationHint; +} + +} + +#endif diff --git a/source/core/StarNetElementSignal.hpp b/source/core/StarNetElementSignal.hpp new file mode 100644 index 0000000..c1a89ac --- /dev/null +++ b/source/core/StarNetElementSignal.hpp @@ -0,0 +1,145 @@ +#ifndef STAR_NET_ELEMENT_SIGNAL_HPP +#define STAR_NET_ELEMENT_SIGNAL_HPP + +#include "StarNetElement.hpp" + +namespace Star { + +// NetElement that sends signals during delta writes that can be received by +// slaves. It has no 'state', and nothing is sent during a store / load, and +// it only keeps past signals for a maximum number of versions. Thus, it is +// not appropriate to use to send updates to long term states, only for event +// like things that are not harmful if missed. +template <typename Signal> +class NetElementSignal : public NetElement { +public: + NetElementSignal(size_t maxSignalQueue = 32); + + void initNetVersion(NetElementVersion const* version = nullptr) override; + + void netStore(DataStream& ds) const override; + void netLoad(DataStream& ds) override; + + void enableNetInterpolation(float extrapolationHint = 0.0f) override; + void disableNetInterpolation() override; + void tickNetInterpolation(float dt) override; + + bool writeNetDelta(DataStream& ds, uint64_t fromVersion) const override; + void readNetDelta(DataStream& ds, float interpolationTime = 0.0) override; + + void send(Signal signal); + List<Signal> receive(); + +private: + struct SignalEntry { + uint64_t version; + Signal signal; + bool received; + }; + + size_t m_maxSignalQueue; + NetElementVersion const* m_netVersion = nullptr; + bool m_netInterpolationEnabled = false; + Deque<SignalEntry> m_signals; + Deque<pair<float, Signal>> m_pendingSignals; +}; + +template <typename Signal> +NetElementSignal<Signal>::NetElementSignal(size_t maxSignalQueue) { + m_maxSignalQueue = maxSignalQueue; +} + +template <typename Signal> +void NetElementSignal<Signal>::initNetVersion(NetElementVersion const* version) { + m_netVersion = version; + m_signals.clear(); +} + +template <typename Signal> +void NetElementSignal<Signal>::netStore(DataStream&) const {} + +template <typename Signal> +void NetElementSignal<Signal>::netLoad(DataStream&) { +} + +template <typename Signal> +void NetElementSignal<Signal>::enableNetInterpolation(float) { + m_netInterpolationEnabled = true; +} + +template <typename Signal> +void NetElementSignal<Signal>::disableNetInterpolation() { + m_netInterpolationEnabled = false; + for (auto& p : take(m_pendingSignals)) + send(move(p.second)); +} + +template <typename Signal> +void NetElementSignal<Signal>::tickNetInterpolation(float dt) { + for (auto& p : m_pendingSignals) + p.first -= dt; + + while (!m_pendingSignals.empty() && m_pendingSignals.first().first <= 0.0f) + send(m_pendingSignals.takeFirst().second); +} + +template <typename Signal> +bool NetElementSignal<Signal>::writeNetDelta(DataStream& ds, uint64_t fromVersion) const { + size_t numToWrite = 0; + for (auto const& p : m_signals) { + if (p.version >= fromVersion) + ++numToWrite; + } + if (numToWrite == 0) + return false; + + ds.writeVlqU(numToWrite); + + for (auto const& p : m_signals) { + if (p.version >= fromVersion) + ds.write(p.signal); + } + + return true; +} + +template <typename Signal> +void NetElementSignal<Signal>::readNetDelta(DataStream& ds, float interpolationTime) { + size_t numToRead = ds.readVlqU(); + for (size_t i = 0; i < numToRead; ++i) { + Signal s; + ds.read(s); + if (m_netInterpolationEnabled && interpolationTime > 0.0f) { + if (!m_pendingSignals.empty() && m_pendingSignals.last().first > interpolationTime) { + for (auto& p : take(m_pendingSignals)) + send(move(p.second)); + } + m_pendingSignals.append({interpolationTime, move(s)}); + } else { + send(move(s)); + } + } +} + +template <typename Signal> +void NetElementSignal<Signal>::send(Signal signal) { + m_signals.append({m_netVersion ? m_netVersion->current() : 0, signal, false}); + while (m_signals.size() > m_maxSignalQueue) + m_signals.removeFirst(); +} + +template <typename Signal> +List<Signal> NetElementSignal<Signal>::receive() { + List<Signal> received; + for (auto& p : m_signals) { + if (!p.received) { + received.append(p.signal); + p.received = true; + } + } + return received; +} + +} + +#endif diff --git a/source/core/StarNetElementSyncGroup.cpp b/source/core/StarNetElementSyncGroup.cpp new file mode 100644 index 0000000..4ead00f --- /dev/null +++ b/source/core/StarNetElementSyncGroup.cpp @@ -0,0 +1,88 @@ +#include "StarNetElementSyncGroup.hpp" + +namespace Star { + +void NetElementSyncGroup::enableNetInterpolation(float extrapolationHint) { + NetElementGroup::enableNetInterpolation(extrapolationHint); + if (m_hasRecentChanges) + netElementsNeedLoad(false); +} + +void NetElementSyncGroup::disableNetInterpolation() { + NetElementGroup::disableNetInterpolation(); + if (m_hasRecentChanges) + netElementsNeedLoad(false); +} + +void NetElementSyncGroup::tickNetInterpolation(float dt) { + NetElementGroup::tickNetInterpolation(dt); + if (m_hasRecentChanges) { + m_recentDeltaTime -= dt; + if (netInterpolationEnabled()) + netElementsNeedLoad(false); + + if (m_recentDeltaTime < 0.0f && m_recentDeltaWasBlank) + m_hasRecentChanges = false; + } +} + +void NetElementSyncGroup::netStore(DataStream& ds) const { + const_cast<NetElementSyncGroup*>(this)->netElementsNeedStore(); + return NetElementGroup::netStore(ds); +} + +void NetElementSyncGroup::netLoad(DataStream& ds) { + NetElementGroup::netLoad(ds); + netElementsNeedLoad(true); +} + +bool NetElementSyncGroup::writeNetDelta(DataStream& ds, uint64_t fromVersion) const { + const_cast<NetElementSyncGroup*>(this)->netElementsNeedStore(); + return NetElementGroup::writeNetDelta(ds, fromVersion); +} + +void NetElementSyncGroup::readNetDelta(DataStream& ds, float interpolationTime) { + NetElementGroup::readNetDelta(ds, interpolationTime); + + m_hasRecentChanges = true; + m_recentDeltaTime = interpolationTime; + m_recentDeltaWasBlank = false; + + netElementsNeedLoad(false); +} + +void NetElementSyncGroup::blankNetDelta(float interpolationTime) { + NetElementGroup::blankNetDelta(interpolationTime); + + if (!m_recentDeltaWasBlank) { + m_recentDeltaTime = interpolationTime; + m_recentDeltaWasBlank = true; + } + + if (m_hasRecentChanges && netInterpolationEnabled()) + netElementsNeedLoad(false); +} + +void NetElementSyncGroup::netElementsNeedLoad(bool) {} + +void NetElementSyncGroup::netElementsNeedStore() {} + +void NetElementCallbackGroup::setNeedsLoadCallback(function<void(bool)> needsLoadCallback) { + m_netElementsNeedLoad = move(needsLoadCallback); +} + +void NetElementCallbackGroup::setNeedsStoreCallback(function<void()> needsStoreCallback) { + m_netElementsNeedStore = move(needsStoreCallback); +} + +void NetElementCallbackGroup::netElementsNeedLoad(bool load) { + if (m_netElementsNeedLoad) + m_netElementsNeedLoad(load); +} + +void NetElementCallbackGroup::netElementsNeedStore() { + if (m_netElementsNeedStore) + m_netElementsNeedStore(); +} + +} diff --git a/source/core/StarNetElementSyncGroup.hpp b/source/core/StarNetElementSyncGroup.hpp new file mode 100644 index 0000000..2162ebe --- /dev/null +++ b/source/core/StarNetElementSyncGroup.hpp @@ -0,0 +1,54 @@ +#ifndef STAR_NET_ELEMENT_SYNC_GROUP_HPP +#define STAR_NET_ELEMENT_SYNC_GROUP_HPP + +#include "StarNetElementGroup.hpp" + +namespace Star { + +// NetElementGroup class that works with NetElements that are not automatically +// kept up to date with working data, and users need to be notified when to +// synchronize with working data. +class NetElementSyncGroup : public NetElementGroup { +public: + void enableNetInterpolation(float extrapolationHint = 0.0f) override; + void disableNetInterpolation() override; + void tickNetInterpolation(float dt) override; + + void netStore(DataStream& ds) const override; + void netLoad(DataStream& ds) override; + + bool writeNetDelta(DataStream& ds, uint64_t fromStep) const override; + void readNetDelta(DataStream& ds, float interpolationTime = 0.0f) override; + void blankNetDelta(float interpolationTime = 0.0f) override; + +protected: + // Notifies when data needs to be pulled from NetElements, load is true if + // this is due to a netLoad call + virtual void netElementsNeedLoad(bool load); + // Notifies when data needs to be pushed to NetElements + virtual void netElementsNeedStore(); + +private: + bool m_hasRecentChanges = false; + float m_recentDeltaTime = 0.0f; + bool m_recentDeltaWasBlank = false; +}; + +// Same as a NetElementSyncGroup, except instead of protected methods, calls +// optional callback functions. +class NetElementCallbackGroup : public NetElementSyncGroup { +public: + void setNeedsLoadCallback(function<void(bool)> needsLoadCallback); + void setNeedsStoreCallback(function<void()> needsStoreCallback); + +private: + void netElementsNeedLoad(bool load) override; + void netElementsNeedStore() override; + + function<void(bool)> m_netElementsNeedLoad; + function<void()> m_netElementsNeedStore; +}; + +} + +#endif diff --git a/source/core/StarNetElementSystem.hpp b/source/core/StarNetElementSystem.hpp new file mode 100644 index 0000000..d5a3560 --- /dev/null +++ b/source/core/StarNetElementSystem.hpp @@ -0,0 +1,19 @@ +#ifndef STAR_NET_ELEMENT_SYSTEM_HPP +#define STAR_NET_ELEMENT_SYSTEM_HPP + +#include "StarNetElementBasicFields.hpp" +#include "StarNetElementFloatFields.hpp" +#include "StarNetElementSyncGroup.hpp" +#include "StarNetElementDynamicGroup.hpp" +#include "StarNetElementContainers.hpp" +#include "StarNetElementSignal.hpp" +#include "StarNetElementTop.hpp" + +namespace Star { + +// Makes a good default top-level NetElement group. +typedef NetElementTop<NetElementCallbackGroup> NetElementTopGroup; + +} + +#endif diff --git a/source/core/StarNetElementTop.hpp b/source/core/StarNetElementTop.hpp new file mode 100644 index 0000000..fa85688 --- /dev/null +++ b/source/core/StarNetElementTop.hpp @@ -0,0 +1,82 @@ +#ifndef STAR_NET_ELEMENT_TOP_HPP +#define STAR_NET_ELEMENT_TOP_HPP + +#include "StarNetElement.hpp" + +namespace Star { + +// Mixin for the NetElement that should be the top element for a network, wraps +// any NetElement class and manages the NetElementVersion. +template <typename BaseNetElement> +class NetElementTop : public BaseNetElement { +public: + NetElementTop(); + + // Returns the state update, combined with the version code that should be + // passed to the next call to writeState. If 'fromVersion' is 0, then this + // is a full write for an initial read of a slave NetElementTop. + pair<ByteArray, uint64_t> writeNetState(uint64_t fromVersion = 0); + // Reads a state produced by a call to writeState, optionally with the + // interpolation delay time for the data contained in this state update. If + // the state is a full update rather than a delta, the interoplation delay + // will be ignored. Blank updates are not necessary to send to be read by + // readState, *unless* extrapolation is enabled. If extrapolation is + // enabled, reading a blank update calls 'blankNetDelta' which is necessary + // to not improperly extrapolate past the end of incoming deltas. + void readNetState(ByteArray data, float interpolationTime = 0.0); + +private: + using BaseNetElement::initNetVersion; + using BaseNetElement::netStore; + using BaseNetElement::netLoad; + using BaseNetElement::writeNetDelta; + using BaseNetElement::readNetDelta; + using BaseNetElement::blankNetDelta; + + NetElementVersion m_netVersion; +}; + +template <typename BaseNetElement> +NetElementTop<BaseNetElement>::NetElementTop() { + BaseNetElement::initNetVersion(&m_netVersion); +} + +template <typename BaseNetElement> +pair<ByteArray, uint64_t> NetElementTop<BaseNetElement>::writeNetState(uint64_t fromVersion) { + if (fromVersion == 0) { + DataStreamBuffer ds; + ds.write<bool>(true); + BaseNetElement::netStore(ds); + m_netVersion.increment(); + return {ds.takeData(), m_netVersion.current()}; + + } else { + DataStreamBuffer ds; + ds.write<bool>(false); + if (!BaseNetElement::writeNetDelta(ds, fromVersion)) { + return {ByteArray(), m_netVersion.current()}; + } else { + m_netVersion.increment(); + return {ds.takeData(), m_netVersion.current()}; + } + } +} + +template <typename BaseNetElement> +void NetElementTop<BaseNetElement>::readNetState(ByteArray data, float interpolationTime) { + if (data.empty()) { + BaseNetElement::blankNetDelta(interpolationTime); + + } else { + DataStreamBuffer ds(move(data)); + + if (ds.read<bool>()) + BaseNetElement::netLoad(ds); + else + BaseNetElement::readNetDelta(ds, interpolationTime); + } +} + +} + +#endif diff --git a/source/core/StarNetImpl.hpp b/source/core/StarNetImpl.hpp new file mode 100644 index 0000000..80d75da --- /dev/null +++ b/source/core/StarNetImpl.hpp @@ -0,0 +1,157 @@ +#ifdef STAR_SYSTEM_FAMILY_WINDOWS +#include <winsock2.h> +#include <ws2tcpip.h> +#include <stdio.h> +#else +#ifdef STAR_SYSTEM_FREEBSD +#include <sys/types.h> +#include <sys/socket.h> +#endif +#include <errno.h> +#include <string.h> +#include <arpa/inet.h> +#include <netdb.h> +#include <netinet/in.h> +#include <netinet/udp.h> +#include <netinet/tcp.h> +#include <unistd.h> +#include <fcntl.h> +#include <poll.h> +#endif + +#include "StarHostAddress.hpp" + +#ifndef AI_ADDRCONFIG +#define AI_ADDRCONFIG 0 +#endif + +namespace Star { + +#ifdef STAR_SYSTEM_FAMILY_WINDOWS +struct WindowsSocketInitializer { + WindowsSocketInitializer() { + WSADATA wsaData; + if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0) + fatalError("WSAStartup failed", false); + }; +}; +static WindowsSocketInitializer g_windowsSocketInitializer; +#endif + +inline String netErrorString() { +#ifdef STAR_SYSTEM_WINDOWS + LPVOID lpMsgBuf = NULL; + + FormatMessage(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, + NULL, + WSAGetLastError(), + MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), // Default language + (LPTSTR)&lpMsgBuf, + 0, + NULL); + + String result = String((char*)lpMsgBuf); + + if (lpMsgBuf != NULL) + LocalFree(lpMsgBuf); + + return result; +#else + return strf("%s - %s", errno, strerror(errno)); +#endif +} + +inline bool netErrorConnectionReset() { +#ifdef STAR_SYSTEM_FAMILY_WINDOWS + return WSAGetLastError() == WSAECONNRESET || WSAGetLastError() == WSAENETRESET; +#else + return errno == ECONNRESET || errno == ETIMEDOUT; +#endif +} + +inline bool netErrorInterrupt() { +#ifdef STAR_SYSTEM_FAMILY_WINDOWS + return WSAGetLastError() == WSAEINTR || WSAGetLastError() == WSAEWOULDBLOCK; +#else + return errno == EAGAIN || errno == EINTR || errno == EWOULDBLOCK; +#endif +} + +inline void setAddressFromNative(HostAddressWithPort& addressWithPort, NetworkMode mode, struct sockaddr_storage* sockAddr) { + switch (mode) { + case NetworkMode::IPv4: { + struct sockaddr_in* addr4 = (struct sockaddr_in*)sockAddr; + addressWithPort = HostAddressWithPort(mode, (uint8_t*)&(addr4->sin_addr.s_addr), ntohs(addr4->sin_port)); + break; + } + case NetworkMode::IPv6: { + struct sockaddr_in6* addr6 = (struct sockaddr_in6*)sockAddr; + addressWithPort = HostAddressWithPort(mode, (uint8_t*)&addr6->sin6_addr.s6_addr, ntohs(addr6->sin6_port)); + break; + } + default: + throw NetworkException("Invalid network mode for setAddressFromNative"); + } +} + +inline void setNativeFromAddress(HostAddressWithPort const& addressWithPort, struct sockaddr_storage* sockAddr, socklen_t* sockAddrLen) { + switch (addressWithPort.address().mode()) { + case NetworkMode::IPv4: { + struct sockaddr_in* addr4 = (struct sockaddr_in*)sockAddr; + *sockAddrLen = sizeof(*addr4); + + memset(addr4, 0, *sockAddrLen); + addr4->sin_family = AF_INET; + addr4->sin_port = htons(addressWithPort.port()); + memcpy(((char*)&addr4->sin_addr.s_addr), addressWithPort.address().bytes(), addressWithPort.address().size()); + + break; + } + case NetworkMode::IPv6: { + struct sockaddr_in6* addr6 = (struct sockaddr_in6*)sockAddr; + *sockAddrLen = sizeof(*addr6); + + memset(addr6, 0, *sockAddrLen); + addr6->sin6_family = AF_INET6; + addr6->sin6_port = htons(addressWithPort.port()); + memcpy(((char*)&addr6->sin6_addr.s6_addr), addressWithPort.address().bytes(), addressWithPort.address().size()); + break; + } + default: + throw NetworkException("Invalid network mode for setNativeFromAddress"); + } +} + +#ifdef STAR_SYSTEM_FAMILY_WINDOWS +inline bool invalidSocketDescriptor(SOCKET socket) { + return socket == INVALID_SOCKET; +} +#else +inline bool invalidSocketDescriptor(int socket) { + return socket < 0; +} +#endif + +struct SocketImpl { + SocketImpl() { + socketDesc = 0; + } + + void setSockOpt(int level, int optname, const void* optval, socklen_t len) { +#ifdef STAR_SYSTEM_FAMILY_WINDOWS + int ret = ::setsockopt(socketDesc, level, optname, (const char*)optval, len); +#else + int ret = ::setsockopt(socketDesc, level, optname, optval, len); +#endif + if (ret < 0) + throw NetworkException(strf("setSockOpt failed to set %d, %d: %s", level, optname, netErrorString())); + } + +#ifdef STAR_SYSTEM_FAMILY_WINDOWS + SOCKET socketDesc; +#else + int socketDesc; +#endif +}; + +} diff --git a/source/core/StarObserverStream.hpp b/source/core/StarObserverStream.hpp new file mode 100644 index 0000000..c97376c --- /dev/null +++ b/source/core/StarObserverStream.hpp @@ -0,0 +1,98 @@ +#ifndef STAR_OBSERVER_STREAM_HPP +#define STAR_OBSERVER_STREAM_HPP + +#include "StarList.hpp" + +namespace Star { + +// Holds a stream of values which separate observers can query and track +// occurrences in the stream without pulling them from the stream. Each +// addition to the stream is given an abstract step value, and queries to the +// stream can reference a given step value in order to track events since the +// last query. +template <typename T> +class ObserverStream { +public: + ObserverStream(uint64_t historyLimit = 0); + + // If a history limit is set, then any entries with step values older than + // the given limit will be discarded automatically. A historyLimit of 0 + // means that no values will be forgotten. The step value increases by one + // with each entry added, or can be increased artificially by a call to + // tickStep. + uint64_t historyLimit() const; + void setHistoryLimit(uint64_t historyLimit = 0); + + // Add a value to the end of the stream and increment the step value by 1. + void add(T value); + + // Artificially tick the step by the given delta, which can be used to clear + // older values. + void tick(uint64_t delta = 1); + + // Query values in the stream since the given step value. Will return the + // values in the stream, and a new since value to pass to query on the next + // call. + pair<List<T>, uint64_t> query(uint64_t since = 0) const; + + // Resets the step value to 0 and clears all values. + void reset(); + +private: + uint64_t m_historyLimit; + uint64_t m_nextStep; + Deque<pair<uint64_t, T>> m_values; +}; + +template <typename T> +ObserverStream<T>::ObserverStream(uint64_t historyLimit) + : m_historyLimit(historyLimit), m_nextStep(0) {} + +template <typename T> +uint64_t ObserverStream<T>::historyLimit() const { + return m_historyLimit; +} + +template <typename T> +void ObserverStream<T>::setHistoryLimit(uint64_t historyLimit) { + m_historyLimit = historyLimit; + tick(0); +} + +template <typename T> +void ObserverStream<T>::add(T value) { + m_values.append({m_nextStep, move(value)}); + tick(1); +} + +template <typename T> +void ObserverStream<T>::tick(uint64_t delta) { + m_nextStep += delta; + uint64_t removeBefore = m_nextStep - min(m_nextStep, m_historyLimit); + while (!m_values.empty() && m_values.first().first < removeBefore) + m_values.removeFirst(); +} + +template <typename T> +pair<List<T>, uint64_t> ObserverStream<T>::query(uint64_t since) const { + List<T> res; + auto i = std::lower_bound(m_values.begin(), + m_values.end(), + since, + [](pair<uint64_t, T> const& p, uint64_t step) { return p.first < step; }); + while (i != m_values.end()) { + res.append(i->second); + ++i; + } + return {res, m_nextStep}; +} + +template <typename T> +void ObserverStream<T>::reset() { + m_nextStep = 0; + m_values.clear(); +} + +} + +#endif diff --git a/source/core/StarOptionParser.cpp b/source/core/StarOptionParser.cpp new file mode 100644 index 0000000..f9d51ff --- /dev/null +++ b/source/core/StarOptionParser.cpp @@ -0,0 +1,162 @@ +#include "StarOptionParser.hpp" +#include "StarIterator.hpp" + +namespace Star { + +void OptionParser::setCommandName(String commandName) { + m_commandName = move(commandName); +} + +void OptionParser::setSummary(String summary) { + m_summary = move(summary); +} + +void OptionParser::setAdditionalHelp(String help) { + m_additionalHelp = move(help); +} + +void OptionParser::addSwitch(String const& flag, String description) { + if (!m_options.insert(flag, Switch{flag, move(description)}).second) + throw OptionParserException::format("Duplicate switch '-%s' added", flag); +} + +void OptionParser::addParameter(String const& flag, String argument, RequirementMode requirementMode, String description) { + if (!m_options.insert(flag, Parameter{flag, move(argument), requirementMode, move(description)}).second) + throw OptionParserException::format("Duplicate flag '-%s' added", flag); +} + +void OptionParser::addArgument(String argument, RequirementMode requirementMode, String description) { + m_arguments.append(Argument{move(argument), requirementMode, move(description)}); +} + +pair<OptionParser::Options, StringList> OptionParser::parseOptions(StringList const& arguments) const { + Options result; + StringList errors; + bool endOfFlags = false; + + auto it = makeSIterator(arguments); + while (it.hasNext()) { + auto const& arg = it.next(); + if (arg == "--") { + endOfFlags = true; + continue; + } + + if (!endOfFlags && arg.beginsWith("-")) { + String flag = arg.substr(1); + auto option = m_options.maybe(flag); + if (!option) { + errors.append(strf("No such option '-%s'", flag)); + continue; + } + + if (option->is<Switch>()) { + result.switches.add(move(flag)); + } else { + auto const& parameter = option->get<Parameter>(); + if (!it.hasNext()) { + errors.append(strf("Option '-%s' must be followed by an argument", flag)); + continue; + } + String val = it.next(); + if (parameter.requirementMode != Multiple && result.parameters.contains(flag)) { + errors.append(strf("Option with argument '-%s' specified multiple times", flag)); + continue; + } + result.parameters[move(flag)].append(move(val)); + } + + } else { + result.arguments.append(move(arg)); + } + } + + for (auto const& pair : m_options) { + if (pair.second.is<Parameter>()) { + auto const& na = pair.second.get<Parameter>(); + if (na.requirementMode == Required && !result.parameters.contains(pair.first)) + errors.append(strf("Missing required flag with argument '-%s'", pair.first)); + } + } + + size_t minimumArguments = 0; + size_t maximumArguments = 0; + for (auto const& argument : m_arguments) { + if ((argument.requirementMode == Optional || argument.requirementMode == Required) && maximumArguments != NPos) + ++maximumArguments; + if (argument.requirementMode == Required) + ++minimumArguments; + if (argument.requirementMode == Multiple) + maximumArguments = NPos; + } + if (result.arguments.size() < minimumArguments) + errors.append(strf( + "Too few positional arguments given, expected at least %s got %s", minimumArguments, result.arguments.size())); + if (result.arguments.size() > maximumArguments) + errors.append(strf( + "Too many positional arguments given, expected at most %s got %s", maximumArguments, result.arguments.size())); + + return {move(result), move(errors)}; +} + +void OptionParser::printHelp(std::ostream& os) const { + if (!m_commandName.empty() && !m_summary.empty()) + format(os, "%s: %s\n\n", m_commandName, m_summary); + else if (!m_commandName.empty()) + format(os, "%s:\n\n", m_commandName); + else if (!m_summary.empty()) + format(os, "%s\n\n", m_summary); + + String cmdLineText; + + for (auto const& p : m_options) { + if (p.second.is<Switch>()) { + cmdLineText += strf(" [-%s]", p.first); + } else { + auto const& parameter = p.second.get<Parameter>(); + if (parameter.requirementMode == Optional) + cmdLineText += strf(" [-%s <%s>]", parameter.flag, parameter.argument); + else if (parameter.requirementMode == Required) + cmdLineText += strf(" -%s <%s>", parameter.flag, parameter.argument); + else if (parameter.requirementMode == Multiple) + cmdLineText += strf(" [-%s <%s>]...", parameter.flag, parameter.argument); + } + } + + for (auto const& p : m_arguments) { + if (p.requirementMode == Optional) + cmdLineText += strf(" [<%s>]", p.argumentName); + else if (p.requirementMode == Required) + cmdLineText += strf(" <%s>", p.argumentName); + else + cmdLineText += strf(" [<%s>...]", p.argumentName); + } + + if (m_commandName.empty()) + format(os, "Command Line Usage:%s\n", cmdLineText); + else + format(os, "Command Line Usage: %s%s\n", m_commandName, cmdLineText); + + for (auto const& p : m_options) { + if (p.second.is<Switch>()) { + auto const& sw = p.second.get<Switch>(); + if (!sw.description.empty()) + format(os, " -%s\t- %s\n", sw.flag, sw.description); + } + if (p.second.is<Parameter>()) { + auto const& parameter = p.second.get<Parameter>(); + if (!parameter.description.empty()) + format(os, " -%s <%s>\t- %s\n", parameter.flag, parameter.argument, parameter.description); + } + } + + for (auto const& p : m_arguments) { + if (!p.description.empty()) + format(os, " <%s>\t- %s\n", p.argumentName, p.description); + } + + if (!m_additionalHelp.empty()) + format(os, "\n%s\n", m_additionalHelp); +} + +} diff --git a/source/core/StarOptionParser.hpp b/source/core/StarOptionParser.hpp new file mode 100644 index 0000000..b17b040 --- /dev/null +++ b/source/core/StarOptionParser.hpp @@ -0,0 +1,84 @@ +#ifndef STAR_OPTION_PARSER_HPP +#define STAR_OPTION_PARSER_HPP + +#include "StarString.hpp" +#include "StarVariant.hpp" +#include "StarOrderedMap.hpp" +#include "StarOrderedSet.hpp" + +namespace Star { + +STAR_EXCEPTION(OptionParserException, StarException); + +// Simple command line argument parsing and help printing, only simple single +// dash flags are supported, no flag combining is allowed and all components +// must be separated by a space. +// +// A 'flag' here refers to a component that is preceded by a dash, like -f or +// -quiet. +// +// Three kinds of things are parsed: +// - 'switches' which are flags that do not have a value, like `-q` for quiet +// - 'parameters' are flags with a value that follows, like `-mode full` +// - 'arguments' are everything else, sorted positionally +class OptionParser { +public: + enum RequirementMode { + Optional, + Required, + Multiple + }; + + struct Options { + OrderedSet<String> switches; + StringMap<StringList> parameters; + StringList arguments; + }; + + void setCommandName(String commandName); + void setSummary(String summary); + void setAdditionalHelp(String help); + + void addSwitch(String const& flag, String description = {}); + void addParameter(String const& flag, String argumentName, RequirementMode mode, String description = {}); + void addArgument(String argumentName, RequirementMode mode, String description = {}); + + // Parse the given arguments into an options set, returns the options parsed + // and a list of all the errors encountered while parsing. + pair<Options, StringList> parseOptions(StringList const& arguments) const; + + // Print help text to the given std::ostream + void printHelp(std::ostream& os) const; + +private: + struct Switch { + String flag; + String description; + }; + + struct Parameter { + String flag; + String argument; + RequirementMode requirementMode; + String description; + }; + + struct Argument { + String argumentName; + RequirementMode requirementMode; + String description; + }; + + typedef Variant<Switch, Parameter> Option; + + String m_commandName; + String m_summary; + String m_additionalHelp; + + OrderedHashMap<String, Option> m_options; + List<Argument> m_arguments; +}; + +} + +#endif diff --git a/source/core/StarOrderedMap.hpp b/source/core/StarOrderedMap.hpp new file mode 100644 index 0000000..61f4c58 --- /dev/null +++ b/source/core/StarOrderedMap.hpp @@ -0,0 +1,657 @@ +#ifndef STAR_ORDERED_MAP_HPP +#define STAR_ORDERED_MAP_HPP + +#include "StarMap.hpp" + +namespace Star { + +// Wraps a normal map type and provides an element order independent of the +// underlying map order. +template <template <typename...> class Map, typename Key, typename Value, typename Allocator, typename... MapArgs> +class OrderedMapWrapper { +public: + typedef Key key_type; + typedef Value mapped_type; + typedef pair<key_type const, mapped_type> value_type; + + typedef LinkedList<value_type, Allocator> OrderType; + typedef Map< + std::reference_wrapper<key_type const>, typename OrderType::iterator, MapArgs..., + typename Allocator::template rebind<pair<std::reference_wrapper<key_type const> const, typename OrderType::iterator>>::other + > MapType; + + typedef typename OrderType::iterator iterator; + typedef typename OrderType::const_iterator const_iterator; + + typedef typename OrderType::reverse_iterator reverse_iterator; + typedef typename OrderType::const_reverse_iterator const_reverse_iterator; + + typedef typename std::decay<mapped_type>::type* mapped_ptr; + typedef typename std::decay<mapped_type>::type const* mapped_const_ptr; + + template <typename Collection> + static OrderedMapWrapper from(Collection const& c); + + OrderedMapWrapper(); + + OrderedMapWrapper(OrderedMapWrapper const& map); + + template <typename InputIterator> + OrderedMapWrapper(InputIterator beg, InputIterator end); + + OrderedMapWrapper(initializer_list<value_type> list); + + List<key_type> keys() const; + List<mapped_type> values() const; + List<pair<key_type, mapped_type>> pairs() const; + + bool contains(key_type const& k) const; + + // Throws MapException if key not found + mapped_type& get(key_type const& k); + mapped_type const& get(key_type const& k) const; + + // Return def if key not found + mapped_type value(key_type const& k, mapped_type d = mapped_type()) const; + + Maybe<mapped_type> maybe(key_type const& k) const; + + mapped_const_ptr ptr(key_type const& k) const; + mapped_ptr ptr(key_type const& k); + + mapped_type& operator[](key_type const& k); + + OrderedMapWrapper& operator=(OrderedMapWrapper const& map); + + bool operator==(OrderedMapWrapper const& m) const; + + // Finds first value matching the given value and returns its key, throws + // MapException if no such value is found. + key_type keyOf(mapped_type const& v) const; + + // Finds all of the values matching the given value and returns their keys. + List<key_type> keysOf(mapped_type const& v) const; + + pair<iterator, bool> insert(value_type const& v); + pair<iterator, bool> insert(key_type k, mapped_type v); + + pair<iterator, bool> insertFront(value_type const& v); + pair<iterator, bool> insertFront(key_type k, mapped_type v); + + // Add a key / value pair, throw if the key already exists + mapped_type& add(key_type k, mapped_type v); + + // Set a key to a value, always override if it already exists + mapped_type& set(key_type k, mapped_type v); + + // Appends all values of given map into this map. If overwite is false, then + // skips values that already exist in this map. Returns false if any keys + // previously existed. + bool merge(OrderedMapWrapper const& m, bool overwrite = false); + + // Removes the item with key k and returns true if found, false otherwise. + bool remove(key_type const& k); + + // Remove and return the value with the key k, throws MapException if not + // found. + mapped_type take(key_type const& k); + + Maybe<value_type> maybeTake(key_type const& k); + + const_iterator begin() const; + const_iterator end() const; + + iterator begin(); + iterator end(); + + const_reverse_iterator rbegin() const; + const_reverse_iterator rend() const; + + reverse_iterator rbegin(); + reverse_iterator rend(); + + size_t size() const; + + iterator erase(iterator i); + size_t erase(key_type const& k); + + iterator find(key_type const& k); + const_iterator find(key_type const& k) const; + + Maybe<size_t> indexOf(key_type const& k) const; + + key_type const& keyAt(size_t i) const; + mapped_type const& valueAt(size_t i) const; + mapped_type& valueAt(size_t i); + + value_type takeFirst(); + void removeFirst(); + + value_type const& first() const; + + key_type const& firstKey() const; + mapped_type& firstValue(); + mapped_type const& firstValue() const; + + iterator insert(iterator pos, value_type v); + + void clear(); + + bool empty() const; + + iterator toBack(iterator i); + void toBack(key_type const& k); + + iterator toFront(iterator i); + void toFront(key_type const& k); + + template <typename Compare> + void sort(Compare comp); + + void sortByKey(); + void sortByValue(); + +private: + MapType m_map; + OrderType m_order; +}; + +template <template <typename...> class Map, typename Key, typename Value, typename Allocator, typename... MapArgs> +std::ostream& operator<<(std::ostream& os, OrderedMapWrapper<Map, Key, Value, Allocator, MapArgs...> const& m); + +template <typename Key, typename Value, typename Compare = std::less<Key>, typename Allocator = std::allocator<pair<Key const, Value>>> +using OrderedMap = OrderedMapWrapper<std::map, Key, Value, Allocator, Compare>; + +template <typename Key, typename Value, typename Hash = Star::hash<Key>, typename Equals = std::equal_to<Key>, typename Allocator = std::allocator<pair<Key const, Value>>> +using OrderedHashMap = OrderedMapWrapper<FlatHashMap, Key, Value, Allocator, Hash, Equals>; + +template <template <typename...> class Map, typename Key, typename Value, typename Allocator, typename... MapArgs> +template <typename Collection> +auto OrderedMapWrapper<Map, Key, Value, Allocator, MapArgs...>::from(Collection const& c) -> OrderedMapWrapper { + return OrderedMapWrapper(c.begin(), c.end()); +} + +template <template <typename...> class Map, typename Key, typename Value, typename Allocator, typename... MapArgs> +OrderedMapWrapper<Map, Key, Value, Allocator, MapArgs...>::OrderedMapWrapper() {} + +template <template <typename...> class Map, typename Key, typename Value, typename Allocator, typename... MapArgs> +OrderedMapWrapper<Map, Key, Value, Allocator, MapArgs...>::OrderedMapWrapper(OrderedMapWrapper const& map) { + for (auto const& p : map) + insert(p); +} + +template <template <typename...> class Map, typename Key, typename Value, typename Allocator, typename... MapArgs> +template <typename InputIterator> +OrderedMapWrapper<Map, Key, Value, Allocator, MapArgs...>::OrderedMapWrapper(InputIterator beg, InputIterator end) { + while (beg != end) { + insert(*beg); + ++beg; + } +} + +template <template <typename...> class Map, typename Key, typename Value, typename Allocator, typename... MapArgs> +OrderedMapWrapper<Map, Key, Value, Allocator, MapArgs...>::OrderedMapWrapper(initializer_list<value_type> list) { + for (value_type v : list) + insert(move(v)); +} + +template <template <typename...> class Map, typename Key, typename Value, typename Allocator, typename... MapArgs> +auto OrderedMapWrapper<Map, Key, Value, Allocator, MapArgs...>::keys() const -> List<key_type> { + List<key_type> keys; + for (auto const& p : *this) + keys.append(p.first); + return keys; +} + +template <template <typename...> class Map, typename Key, typename Value, typename Allocator, typename... MapArgs> +auto OrderedMapWrapper<Map, Key, Value, Allocator, MapArgs...>::values() const -> List<mapped_type> { + List<mapped_type> values; + for (auto const& p : *this) + values.append(p.second); + return values; +} + +template <template <typename...> class Map, typename Key, typename Value, typename Allocator, typename... MapArgs> +auto OrderedMapWrapper<Map, Key, Value, Allocator, MapArgs...>::pairs() const -> List<pair<key_type, mapped_type>> { + List<pair<key_type, mapped_type>> plist; + for (auto const& p : *this) + plist.append(p.second); + return plist; +} + +template <template <typename...> class Map, typename Key, typename Value, typename Allocator, typename... MapArgs> +bool OrderedMapWrapper<Map, Key, Value, Allocator, MapArgs...>::contains(key_type const& k) const { + return m_map.find(k) != m_map.end(); +} + +template <template <typename...> class Map, typename Key, typename Value, typename Allocator, typename... MapArgs> +auto OrderedMapWrapper<Map, Key, Value, Allocator, MapArgs...>::get(key_type const& k) -> mapped_type& { + auto i = m_map.find(k); + if (i == m_map.end()) + throw MapException(strf("Key '%s' not found in OrderedMap::get()", outputAny(k))); + + return i->second->second; +} + +template <template <typename...> class Map, typename Key, typename Value, typename Allocator, typename... MapArgs> +auto OrderedMapWrapper<Map, Key, Value, Allocator, MapArgs...>::get(key_type const& k) const -> mapped_type const& { + return const_cast<OrderedMapWrapper*>(this)->get(k); +} + +template <template <typename...> class Map, typename Key, typename Value, typename Allocator, typename... MapArgs> +auto OrderedMapWrapper<Map, Key, Value, Allocator, MapArgs...>::value(key_type const& k, mapped_type d) const -> mapped_type { + auto i = m_map.find(k); + if (i == m_map.end()) + return move(d); + else + return i->second->second; +} + +template <template <typename...> class Map, typename Key, typename Value, typename Allocator, typename... MapArgs> +auto OrderedMapWrapper<Map, Key, Value, Allocator, MapArgs...>::maybe(key_type const& k) const -> Maybe<mapped_type> { + auto i = find(k); + if (i == end()) + return {}; + else + return i->second; +} + +template <template <typename...> class Map, typename Key, typename Value, typename Allocator, typename... MapArgs> +auto OrderedMapWrapper<Map, Key, Value, Allocator, MapArgs...>::ptr(key_type const& k) const -> mapped_const_ptr { + auto i = find(k); + if (i == end()) + return nullptr; + else + return &i->second; +} + +template <template <typename...> class Map, typename Key, typename Value, typename Allocator, typename... MapArgs> +auto OrderedMapWrapper<Map, Key, Value, Allocator, MapArgs...>::ptr(key_type const& k) -> mapped_ptr { + iterator i = find(k); + if (i == end()) + return nullptr; + else + return &i->second; +} + +template <template <typename...> class Map, typename Key, typename Value, typename Allocator, typename... MapArgs> +auto OrderedMapWrapper<Map, Key, Value, Allocator, MapArgs...>::operator[](key_type const& k) -> mapped_type& { + auto i = m_map.find(k); + if (i == m_map.end()) { + iterator orderIt = m_order.insert(m_order.end(), value_type(k, mapped_type())); + i = m_map.insert(typename MapType::value_type(std::cref(orderIt->first), orderIt)).first; + return orderIt->second; + } else { + return i->second->second; + } +} + +template <template <typename...> class Map, typename Key, typename Value, typename Allocator, typename... MapArgs> +auto OrderedMapWrapper<Map, Key, Value, Allocator, MapArgs...>::operator=(OrderedMapWrapper const& map) -> OrderedMapWrapper& { + if (this != &map) { + clear(); + for (auto const& p : map) + insert(p); + } + + return *this; +} + +template <template <typename...> class Map, typename Key, typename Value, typename Allocator, typename... MapArgs> +bool OrderedMapWrapper<Map, Key, Value, Allocator, MapArgs...>::operator==(OrderedMapWrapper const& m) const { + return this == &m || mapsEqual(*this, m); +} + +template <template <typename...> class Map, typename Key, typename Value, typename Allocator, typename... MapArgs> +auto OrderedMapWrapper<Map, Key, Value, Allocator, MapArgs...>::keyOf(mapped_type const& v) const -> key_type { + for (const_iterator i = begin(); i != end(); ++i) { + if (i->second == v) + return i->first; + } + throw MapException(strf("Value '%s' not found in OrderedMap::keyOf()", outputAny(v))); +} + +template <template <typename...> class Map, typename Key, typename Value, typename Allocator, typename... MapArgs> +auto OrderedMapWrapper<Map, Key, Value, Allocator, MapArgs...>::keysOf(mapped_type const& v) const -> List<key_type> { + List<key_type> keys; + for (iterator i = begin(); i != end(); ++i) { + if (i->second == v) + keys.append(i->first); + } + return keys; +} + +template <template <typename...> class Map, typename Key, typename Value, typename Allocator, typename... MapArgs> +auto OrderedMapWrapper<Map, Key, Value, Allocator, MapArgs...>::insert(value_type const& v) -> pair<iterator, bool> { + auto i = m_map.find(v.first); + if (i == m_map.end()) { + iterator orderIt = m_order.insert(m_order.end(), v); + m_map.insert(i, typename MapType::value_type(std::cref(orderIt->first), orderIt)); + return std::make_pair(orderIt, true); + } else { + return std::make_pair(i->second, false); + } +} + +template <template <typename...> class Map, typename Key, typename Value, typename Allocator, typename... MapArgs> +auto OrderedMapWrapper<Map, Key, Value, Allocator, MapArgs...>::insert(key_type k, mapped_type v) -> pair<iterator, bool> { + return insert(value_type(move(k), move(v))); +} + +template <template <typename...> class Map, typename Key, typename Value, typename Allocator, typename... MapArgs> +auto OrderedMapWrapper<Map, Key, Value, Allocator, MapArgs...>::insertFront(value_type const& v) -> pair<iterator, bool> { + auto i = m_map.find(v.first); + if (i == m_map.end()) { + iterator orderIt = m_order.insert(m_order.begin(), v); + m_map.insert(i, typename MapType::value_type(std::cref(orderIt->first), orderIt)); + return std::make_pair(orderIt, true); + } else { + return std::make_pair(i->second, false); + } +} + +template <template <typename...> class Map, typename Key, typename Value, typename Allocator, typename... MapArgs> +auto OrderedMapWrapper<Map, Key, Value, Allocator, MapArgs...>::insertFront(key_type k, mapped_type v) -> pair<iterator, bool> { + return insertFront(value_type(move(k), move(v))); +} + +template <template <typename...> class Map, typename Key, typename Value, typename Allocator, typename... MapArgs> +auto OrderedMapWrapper<Map, Key, Value, Allocator, MapArgs...>::add(key_type k, mapped_type v) -> mapped_type& { + auto pair = insert(value_type(move(k), move(v))); + if (!pair.second) + throw MapException(strf("Entry with key '%s' already present.", outputAny(k))); + else + return pair.first->second; +} + +template <template <typename...> class Map, typename Key, typename Value, typename Allocator, typename... MapArgs> +auto OrderedMapWrapper<Map, Key, Value, Allocator, MapArgs...>::set(key_type k, mapped_type v) -> mapped_type& { + auto i = find(k); + if (i != end()) { + i->second = move(v); + return i->second; + } else { + return insert(value_type(move(k), move(v))).first->second; + } +} + +template <template <typename...> class Map, typename Key, typename Value, typename Allocator, typename... MapArgs> +bool OrderedMapWrapper<Map, Key, Value, Allocator, MapArgs...>::merge(OrderedMapWrapper const& m, bool overwrite) { + return mapMerge(*this, m, overwrite); +} + +template <template <typename...> class Map, typename Key, typename Value, typename Allocator, typename... MapArgs> +bool OrderedMapWrapper<Map, Key, Value, Allocator, MapArgs...>::remove(key_type const& k) { + auto i = m_map.find(k); + if (i != m_map.end()) { + auto orderIt = i->second; + m_map.erase(i); + + m_order.erase(orderIt); + return true; + } else { + return false; + } +} + +template <template <typename...> class Map, typename Key, typename Value, typename Allocator, typename... MapArgs> +auto OrderedMapWrapper<Map, Key, Value, Allocator, MapArgs...>::take(key_type const& k) -> mapped_type { + auto i = m_map.find(k); + if (i != m_map.end()) { + auto orderIt = i->second; + m_map.erase(i); + + mapped_type v = orderIt->second; + m_order.erase(i->second); + return v; + } else { + throw MapException(strf("Key '%s' not found in OrderedMap::take()", outputAny(k))); + } +} + +template <template <typename...> class Map, typename Key, typename Value, typename Allocator, typename... MapArgs> +auto OrderedMapWrapper<Map, Key, Value, Allocator, MapArgs...>::maybeTake(key_type const& k) -> Maybe<value_type> { + iterator i = find(k); + if (i != end()) { + value_type v = *i; + erase(i); + return v; + } + + return {}; +} + +template <template <typename...> class Map, typename Key, typename Value, typename Allocator, typename... MapArgs> +auto OrderedMapWrapper<Map, Key, Value, Allocator, MapArgs...>::begin() const -> const_iterator { + return m_order.begin(); +} + +template <template <typename...> class Map, typename Key, typename Value, typename Allocator, typename... MapArgs> +auto OrderedMapWrapper<Map, Key, Value, Allocator, MapArgs...>::end() const -> const_iterator { + return m_order.end(); +} + +template <template <typename...> class Map, typename Key, typename Value, typename Allocator, typename... MapArgs> +auto OrderedMapWrapper<Map, Key, Value, Allocator, MapArgs...>::begin() -> iterator { + return m_order.begin(); +} + +template <template <typename...> class Map, typename Key, typename Value, typename Allocator, typename... MapArgs> +auto OrderedMapWrapper<Map, Key, Value, Allocator, MapArgs...>::end() -> iterator { + return m_order.end(); +} + +template <template <typename...> class Map, typename Key, typename Value, typename Allocator, typename... MapArgs> +auto OrderedMapWrapper<Map, Key, Value, Allocator, MapArgs...>::rbegin() const -> const_reverse_iterator { + return m_order.rbegin(); +} + +template <template <typename...> class Map, typename Key, typename Value, typename Allocator, typename... MapArgs> +auto OrderedMapWrapper<Map, Key, Value, Allocator, MapArgs...>::rend() const -> const_reverse_iterator { + return m_order.rend(); +} + +template <template <typename...> class Map, typename Key, typename Value, typename Allocator, typename... MapArgs> +auto OrderedMapWrapper<Map, Key, Value, Allocator, MapArgs...>::rbegin() -> reverse_iterator { + return m_order.rbegin(); +} + +template <template <typename...> class Map, typename Key, typename Value, typename Allocator, typename... MapArgs> +auto OrderedMapWrapper<Map, Key, Value, Allocator, MapArgs...>::rend() -> reverse_iterator { + return m_order.rend(); +} + +template <template <typename...> class Map, typename Key, typename Value, typename Allocator, typename... MapArgs> +auto OrderedMapWrapper<Map, Key, Value, Allocator, MapArgs...>::size() const -> size_t { + return m_map.size(); +} + +template <template <typename...> class Map, typename Key, typename Value, typename Allocator, typename... MapArgs> +auto OrderedMapWrapper<Map, Key, Value, Allocator, MapArgs...>::erase(iterator i) -> iterator { + m_map.erase(i->first); + return m_order.erase(i); +} + +template <template <typename...> class Map, typename Key, typename Value, typename Allocator, typename... MapArgs> +auto OrderedMapWrapper<Map, Key, Value, Allocator, MapArgs...>::erase(key_type const& k) -> size_t { + if (remove(k)) + return 1; + return 0; +} + +template <template <typename...> class Map, typename Key, typename Value, typename Allocator, typename... MapArgs> +auto OrderedMapWrapper<Map, Key, Value, Allocator, MapArgs...>::find(key_type const& k) -> iterator { + auto i = m_map.find(k); + if (i == m_map.end()) + return m_order.end(); + else + return i->second; +} + +template <template <typename...> class Map, typename Key, typename Value, typename Allocator, typename... MapArgs> +auto OrderedMapWrapper<Map, Key, Value, Allocator, MapArgs...>::find(key_type const& k) const -> const_iterator { + auto i = m_map.find(k); + if (i == m_map.end()) + return m_order.end(); + else + return i->second; +} + +template <template <typename...> class Map, typename Key, typename Value, typename Allocator, typename... MapArgs> +auto OrderedMapWrapper<Map, Key, Value, Allocator, MapArgs...>::indexOf(key_type const& k) const -> Maybe<size_t> { + typename MapType::const_iterator i = m_map.find(k); + if (i == m_map.end()) + return {}; + + return std::distance(begin(), const_iterator(i->second)); +} + +template <template <typename...> class Map, typename Key, typename Value, typename Allocator, typename... MapArgs> +auto OrderedMapWrapper<Map, Key, Value, Allocator, MapArgs...>::keyAt(size_t i) const -> key_type const& { + if (i >= size()) + throw MapException(strf("index %s out of range in OrderedMap::at()", i)); + + auto it = begin(); + std::advance(it, i); + return it->first; +} + +template <template <typename...> class Map, typename Key, typename Value, typename Allocator, typename... MapArgs> +auto OrderedMapWrapper<Map, Key, Value, Allocator, MapArgs...>::valueAt(size_t i) const -> mapped_type const& { + return const_cast<OrderedMapWrapper*>(this)->valueAt(i); +} + +template <template <typename...> class Map, typename Key, typename Value, typename Allocator, typename... MapArgs> +auto OrderedMapWrapper<Map, Key, Value, Allocator, MapArgs...>::valueAt(size_t i) -> mapped_type& { + if (i >= size()) + throw MapException(strf("index %s out of range in OrderedMap::valueAt()", i)); + + auto it = m_order.begin(); + std::advance(it, i); + return it->second; +} + +template <template <typename...> class Map, typename Key, typename Value, typename Allocator, typename... MapArgs> +auto OrderedMapWrapper<Map, Key, Value, Allocator, MapArgs...>::takeFirst() -> value_type { + if (empty()) + throw MapException("OrderedMap::takeFirst() called on empty OrderedMap"); + + iterator i = m_order.begin(); + m_map.remove(i->first); + value_type v = *i; + m_order.erase(i); + return v; +} + +template <template <typename...> class Map, typename Key, typename Value, typename Allocator, typename... MapArgs> +void OrderedMapWrapper<Map, Key, Value, Allocator, MapArgs...>::removeFirst() { + erase(begin()); +} + +template <template <typename...> class Map, typename Key, typename Value, typename Allocator, typename... MapArgs> +auto OrderedMapWrapper<Map, Key, Value, Allocator, MapArgs...>::first() const -> value_type const& { + if (empty()) + throw MapException("OrderedMap::takeFirst() called on empty OrderedMap"); + + return *m_order.begin(); +} + +template <template <typename...> class Map, typename Key, typename Value, typename Allocator, typename... MapArgs> +auto OrderedMapWrapper<Map, Key, Value, Allocator, MapArgs...>::firstValue() -> mapped_type& { + return begin()->second; +} + +template <template <typename...> class Map, typename Key, typename Value, typename Allocator, typename... MapArgs> +auto OrderedMapWrapper<Map, Key, Value, Allocator, MapArgs...>::firstValue() const -> mapped_type const& { + return begin()->second; +} + +template <template <typename...> class Map, typename Key, typename Value, typename Allocator, typename... MapArgs> +auto OrderedMapWrapper<Map, Key, Value, Allocator, MapArgs...>::firstKey() const -> key_type const& { + return begin()->first; +} + +template <template <typename...> class Map, typename Key, typename Value, typename Allocator, typename... MapArgs> +auto OrderedMapWrapper<Map, Key, Value, Allocator, MapArgs...>::insert(iterator pos, value_type v) -> iterator { + auto i = m_map.find(v.first); + if (i == m_map.end()) { + iterator orderIt = m_order.insert(pos, move(v)); + m_map.insert(typename MapType::value_type(std::cref(orderIt->first), orderIt)); + return orderIt; + } else { + i->second->second = move(v.second); + m_order.splice(pos, m_order, i->second); + return i->second; + } +} + +template <template <typename...> class Map, typename Key, typename Value, typename Allocator, typename... MapArgs> +void OrderedMapWrapper<Map, Key, Value, Allocator, MapArgs...>::clear() { + m_map.clear(); + m_order.clear(); +} + +template <template <typename...> class Map, typename Key, typename Value, typename Allocator, typename... MapArgs> +bool OrderedMapWrapper<Map, Key, Value, Allocator, MapArgs...>::empty() const { + return size() == 0; +} + +template <template <typename...> class Map, typename Key, typename Value, typename Allocator, typename... MapArgs> +auto OrderedMapWrapper<Map, Key, Value, Allocator, MapArgs...>::toBack(iterator i) -> iterator { + m_order.splice(m_order.end(), m_order, i); + return prev(m_order.end()); +} + +template <template <typename...> class Map, typename Key, typename Value, typename Allocator, typename... MapArgs> +auto OrderedMapWrapper<Map, Key, Value, Allocator, MapArgs...>::toFront(iterator i) -> iterator { + m_order.splice(m_order.begin(), m_order, i); + return m_order.begin(); +} + +template <template <typename...> class Map, typename Key, typename Value, typename Allocator, typename... MapArgs> +void OrderedMapWrapper<Map, Key, Value, Allocator, MapArgs...>::toBack(key_type const& k) { + auto i = m_map.find(k); + if (i == m_map.end()) + throw MapException(strf("Key not found in OrderedMap::toBack('%s')", outputAny(k))); + + toBack(i->second); +} + +template <template <typename...> class Map, typename Key, typename Value, typename Allocator, typename... MapArgs> +void OrderedMapWrapper<Map, Key, Value, Allocator, MapArgs...>::toFront(key_type const& k) { + auto i = m_map.find(k); + if (i == m_map.end()) + throw MapException(strf("Key not found in OrderedMap::toFront('%s')", outputAny(k))); + + toFront(i->second); +} + +template <template <typename...> class Map, typename Key, typename Value, typename Allocator, typename... MapArgs> +template <typename Compare> +void OrderedMapWrapper<Map, Key, Value, Allocator, MapArgs...>::sort(Compare comp) { + m_order.sort(comp); +} + +template <template <typename...> class Map, typename Key, typename Value, typename Allocator, typename... MapArgs> +void OrderedMapWrapper<Map, Key, Value, Allocator, MapArgs...>::sortByKey() { + sort([](value_type const& a, value_type const& b) { + return a.first < b.first; + }); +} + +template <template <typename...> class Map, typename Key, typename Value, typename Allocator, typename... MapArgs> +void OrderedMapWrapper<Map, Key, Value, Allocator, MapArgs...>::sortByValue() { + sort([](value_type const& a, value_type const& b) { + return a.second < b.second; + }); +} + +template <template <typename...> class Map, typename Key, typename Value, typename Allocator, typename... MapArgs> +std::ostream& operator<<(std::ostream& os, OrderedMapWrapper<Map, Key, Value, Allocator, MapArgs...> const& m) { + printMap(os, m); + return os; +} + +} + +#endif diff --git a/source/core/StarOrderedSet.hpp b/source/core/StarOrderedSet.hpp new file mode 100644 index 0000000..319db79 --- /dev/null +++ b/source/core/StarOrderedSet.hpp @@ -0,0 +1,430 @@ +#ifndef STAR_ORDERED_SET_HPP +#define STAR_ORDERED_SET_HPP + +#include <map> + +#include "StarFlatHashMap.hpp" +#include "StarSet.hpp" +#include "StarList.hpp" + +namespace Star { + +template <template <typename...> class Map, typename Value, typename Allocator, typename... Args> +class OrderedSetWrapper { +public: + typedef Value value_type; + + typedef LinkedList<value_type, typename Allocator::template rebind<value_type>::other> OrderType; + typedef Map< + std::reference_wrapper<value_type const>, typename OrderType::const_iterator, Args..., + typename Allocator::template rebind<pair<std::reference_wrapper<value_type const> const, typename OrderType::const_iterator>>::other + > MapType; + + typedef typename OrderType::const_iterator const_iterator; + typedef const_iterator iterator; + + typedef typename OrderType::const_reverse_iterator const_reverse_iterator; + typedef const_reverse_iterator reverse_iterator; + + template <typename Collection> + static OrderedSetWrapper from(Collection const& c); + + OrderedSetWrapper(); + OrderedSetWrapper(OrderedSetWrapper const& set); + + template <typename InputIterator> + OrderedSetWrapper(InputIterator beg, InputIterator end); + + OrderedSetWrapper(initializer_list<value_type> list); + + OrderedSetWrapper& operator=(OrderedSetWrapper const& set); + + // Guaranteed to be in order. + List<value_type> values() const; + + bool contains(value_type const& v) const; + + // add either adds the value to the back, or does not move it from its + // current order. + pair<iterator, bool> insert(value_type const& v); + + // like insert, but only returns whether the value was added or not. + bool add(Value const& v); + + // Always replaces an existing value with a new value if it exists, and + // always moves to the back. + bool replace(Value const& v); + + // Either adds a value to the end of the order, or moves an existing value to + // the back. + bool addBack(Value const& v); + + // Either adds a value to the beginning of the order, or moves an existing + // value to the beginning. + bool addFront(Value const& v); + + template <typename Container> + void addAll(Container const& c); + + iterator toFront(iterator i); + + iterator toBack(iterator i); + + bool remove(value_type const& v); + + template <typename Container> + void removeAll(Container const& c); + + void clear(); + + value_type const& first() const; + value_type const& last() const; + + void removeFirst(); + void removeLast(); + + value_type takeFirst(); + value_type takeLast(); + + template <typename Compare> + void sort(Compare comp); + + void sort(); + + size_t empty() const; + size_t size() const; + + const_iterator begin() const; + const_iterator end() const; + + const_reverse_iterator rbegin() const; + const_reverse_iterator rend() const; + + Maybe<size_t> indexOf(value_type const& v) const; + + value_type const& at(size_t i) const; + value_type& at(size_t i); + + OrderedSetWrapper intersection(OrderedSetWrapper const& s) const; + OrderedSetWrapper difference(OrderedSetWrapper const& s) const; + +private: + MapType m_map; + OrderType m_order; +}; + +template <template <typename...> class Map, typename Value, typename Allocator, typename... Args> +std::ostream& operator<<(std::ostream& os, OrderedSetWrapper<Map, Value, Allocator, Args...> const& set); + +template <typename Value, typename Compare = std::less<Value>, typename Allocator = std::allocator<Value>> +using OrderedSet = OrderedSetWrapper<std::map, Value, Allocator, Compare>; + +template <typename Value, typename Hash = Star::hash<Value>, typename Equals = std::equal_to<Value>, typename Allocator = std::allocator<Value>> +using OrderedHashSet = OrderedSetWrapper<FlatHashMap, Value, Allocator, Hash, Equals>; + +template <template <typename...> class Map, typename Value, typename Allocator, typename... Args> +template <typename Collection> +auto OrderedSetWrapper<Map, Value, Allocator, Args...>::from(Collection const& c) -> OrderedSetWrapper { + return OrderedSetWrapper(c.begin(), c.end()); +} + +template <template <typename...> class Map, typename Value, typename Allocator, typename... Args> +OrderedSetWrapper<Map, Value, Allocator, Args...>::OrderedSetWrapper() {} + +template <template <typename...> class Map, typename Value, typename Allocator, typename... Args> +OrderedSetWrapper<Map, Value, Allocator, Args...>::OrderedSetWrapper(OrderedSetWrapper const& set) { + for (auto const& p : set) + add(p); +} + +template <template <typename...> class Map, typename Value, typename Allocator, typename... Args> +template <typename InputIterator> +OrderedSetWrapper<Map, Value, Allocator, Args...>::OrderedSetWrapper(InputIterator beg, InputIterator end) { + while (beg != end) { + add(*beg); + ++beg; + } +} + +template <template <typename...> class Map, typename Value, typename Allocator, typename... Args> +OrderedSetWrapper<Map, Value, Allocator, Args...>::OrderedSetWrapper(initializer_list<value_type> list) { + for (value_type const& v : list) + add(v); +} + +template <template <typename...> class Map, typename Value, typename Allocator, typename... Args> +auto OrderedSetWrapper<Map, Value, Allocator, Args...>::operator=(OrderedSetWrapper const& set) -> OrderedSetWrapper& { + if (this != &set) { + clear(); + for (auto const& p : set) + add(p); + } + + return *this; +} + +template <template <typename...> class Map, typename Value, typename Allocator, typename... Args> +auto OrderedSetWrapper<Map, Value, Allocator, Args...>::values() const -> List<value_type> { + List<value_type> values; + for (auto p : *this) + values.append(move(p)); + return values; +} + +template <template <typename...> class Map, typename Value, typename Allocator, typename... Args> +bool OrderedSetWrapper<Map, Value, Allocator, Args...>::contains(value_type const& v) const { + return m_map.find(v) != m_map.end(); +} + +template <template <typename...> class Map, typename Value, typename Allocator, typename... Args> +auto OrderedSetWrapper<Map, Value, Allocator, Args...>::insert(value_type const& v) -> pair<iterator, bool> { + auto i = m_map.find(v); + if (i == m_map.end()) { + auto orderIt = m_order.insert(m_order.end(), v); + m_map.insert(typename MapType::value_type(std::cref(*orderIt), orderIt)); + return {orderIt, true}; + } + return {i->second, false}; +} + +template <template <typename...> class Map, typename Value, typename Allocator, typename... Args> +bool OrderedSetWrapper<Map, Value, Allocator, Args...>::add(Value const& v) { + return insert(v).second; +} + +template <template <typename...> class Map, typename Value, typename Allocator, typename... Args> +bool OrderedSetWrapper<Map, Value, Allocator, Args...>::replace(Value const& v) { + bool replaced = remove(v); + add(v); + return replaced; +} + +template <template <typename...> class Map, typename Value, typename Allocator, typename... Args> +bool OrderedSetWrapper<Map, Value, Allocator, Args...>::addBack(Value const& v) { + auto i = m_map.find(v); + if (i != m_map.end()) { + m_order.splice(m_order.end(), m_order, i->second); + return false; + } else { + iterator orderIt = m_order.insert(m_order.end(), v); + m_map.insert(typename MapType::value_type(std::cref(*orderIt), orderIt)); + return true; + } +} + +template <template <typename...> class Map, typename Value, typename Allocator, typename... Args> +bool OrderedSetWrapper<Map, Value, Allocator, Args...>::addFront(Value const& v) { + auto i = m_map.find(v); + if (i != m_map.end()) { + m_order.splice(m_order.begin(), m_order, i->second); + return false; + } else { + iterator orderIt = m_order.insert(m_order.end(), v); + m_map.insert(typename MapType::value_type(std::cref(*orderIt), orderIt)); + return true; + } +} + +template <template <typename...> class Map, typename Value, typename Allocator, typename... Args> +template <typename Container> +void OrderedSetWrapper<Map, Value, Allocator, Args...>::addAll(Container const& c) { + for (auto const& v : c) + add(v); +} + +template <template <typename...> class Map, typename Value, typename Allocator, typename... Args> +auto OrderedSetWrapper<Map, Value, Allocator, Args...>::toFront(iterator i) -> iterator { + m_order.splice(m_order.begin(), m_order, i); + return m_order.begin(); +} + +template <template <typename...> class Map, typename Value, typename Allocator, typename... Args> +auto OrderedSetWrapper<Map, Value, Allocator, Args...>::toBack(iterator i) -> iterator { + m_order.splice(m_order.end(), m_order, i); + return prev(m_order.end()); +} + +template <template <typename...> class Map, typename Value, typename Allocator, typename... Args> +bool OrderedSetWrapper<Map, Value, Allocator, Args...>::remove(value_type const& v) { + auto i = m_map.find(v); + if (i != m_map.end()) { + auto orderIt = i->second; + m_map.erase(i); + m_order.erase(orderIt); + return true; + } + return false; +} + +template <template <typename...> class Map, typename Value, typename Allocator, typename... Args> +template <typename Container> +void OrderedSetWrapper<Map, Value, Allocator, Args...>::removeAll(Container const& c) { + for (auto const& v : c) + remove(v); +} + +template <template <typename...> class Map, typename Value, typename Allocator, typename... Args> +void OrderedSetWrapper<Map, Value, Allocator, Args...>::clear() { + m_map.clear(); + m_order.clear(); +} + +template <template <typename...> class Map, typename Value, typename Allocator, typename... Args> +auto OrderedSetWrapper<Map, Value, Allocator, Args...>::first() const -> value_type const& { + if (empty()) + throw SetException("first() called on empty OrderedSet"); + return *begin(); +} + +template <template <typename...> class Map, typename Value, typename Allocator, typename... Args> +auto OrderedSetWrapper<Map, Value, Allocator, Args...>::last() const -> value_type const& { + if (empty()) + throw SetException("last() called on empty OrderedSet"); + return *(prev(end())); +} + +template <template <typename...> class Map, typename Value, typename Allocator, typename... Args> +void OrderedSetWrapper<Map, Value, Allocator, Args...>::removeFirst() { + if (empty()) + throw SetException("OrderedSet::removeFirst() called on empty OrderedSet"); + + auto i = m_order.begin(); + m_map.erase(*i); + m_order.erase(i); +} + +template <template <typename...> class Map, typename Value, typename Allocator, typename... Args> +void OrderedSetWrapper<Map, Value, Allocator, Args...>::removeLast() { + if (empty()) + throw SetException("OrderedSet::removeLast() called on empty OrderedSet"); + + auto i = m_order.end(); + --i; + m_map.erase(*i); + m_order.erase(i); +} + +template <template <typename...> class Map, typename Value, typename Allocator, typename... Args> +auto OrderedSetWrapper<Map, Value, Allocator, Args...>::takeFirst() -> value_type { + if (empty()) + throw SetException("OrderedSet::takeFirst() called on empty OrderedSet"); + + auto i = m_order.begin(); + m_map.erase(*i); + value_type v = *i; + m_order.erase(i); + return v; +} + +template <template <typename...> class Map, typename Value, typename Allocator, typename... Args> +auto OrderedSetWrapper<Map, Value, Allocator, Args...>::takeLast() -> value_type { + if (empty()) + throw SetException("OrderedSet::takeLast() called on empty OrderedSet"); + + auto i = m_order.end(); + --i; + m_map.erase(*i); + value_type v = *i; + m_order.erase(i); + return v; +} + +template <template <typename...> class Map, typename Value, typename Allocator, typename... Args> +template <typename Compare> +void OrderedSetWrapper<Map, Value, Allocator, Args...>::sort(Compare comp) { + m_order.sort(comp); +} + +template <template <typename...> class Map, typename Value, typename Allocator, typename... Args> +void OrderedSetWrapper<Map, Value, Allocator, Args...>::sort() { + m_order.sort(); +} + +template <template <typename...> class Map, typename Value, typename Allocator, typename... Args> +size_t OrderedSetWrapper<Map, Value, Allocator, Args...>::empty() const { + return m_map.empty(); +} + +template <template <typename...> class Map, typename Value, typename Allocator, typename... Args> +size_t OrderedSetWrapper<Map, Value, Allocator, Args...>::size() const { + return m_map.size(); +} + +template <template <typename...> class Map, typename Value, typename Allocator, typename... Args> +auto OrderedSetWrapper<Map, Value, Allocator, Args...>::begin() const -> const_iterator { + return m_order.begin(); +} + +template <template <typename...> class Map, typename Value, typename Allocator, typename... Args> +auto OrderedSetWrapper<Map, Value, Allocator, Args...>::end() const -> const_iterator { + return m_order.end(); +} + +template <template <typename...> class Map, typename Value, typename Allocator, typename... Args> +auto OrderedSetWrapper<Map, Value, Allocator, Args...>::rbegin() const -> const_reverse_iterator { + return m_order.rbegin(); +} + +template <template <typename...> class Map, typename Value, typename Allocator, typename... Args> +auto OrderedSetWrapper<Map, Value, Allocator, Args...>::rend() const -> const_reverse_iterator { + return m_order.rend(); +} + +template <template <typename...> class Map, typename Value, typename Allocator, typename... Args> +Maybe<size_t> OrderedSetWrapper<Map, Value, Allocator, Args...>::indexOf(value_type const& v) const { + auto i = m_map.find(v); + if (i == m_map.end()) + return {}; + + return std::distance(begin(), const_iterator(i->second)); +} + +template <template <typename...> class Map, typename Value, typename Allocator, typename... Args> +auto OrderedSetWrapper<Map, Value, Allocator, Args...>::at(size_t i) const -> value_type const& { + auto it = begin(); + std::advance(it, i); + return *it; +} + +template <template <typename...> class Map, typename Value, typename Allocator, typename... Args> +auto OrderedSetWrapper<Map, Value, Allocator, Args...>::at(size_t i) -> value_type& { + auto it = begin(); + std::advance(it, i); + return *it; +} + +template <template <typename...> class Map, typename Value, typename Allocator, typename... Args> +auto OrderedSetWrapper<Map, Value, Allocator, Args...>::intersection(OrderedSetWrapper const& s) const -> OrderedSetWrapper { + OrderedSetWrapper ret; + for (auto const& e : s) { + if (contains(e)) + ret.add(e); + } + return ret; +} + +template <template <typename...> class Map, typename Value, typename Allocator, typename... Args> +auto OrderedSetWrapper<Map, Value, Allocator, Args...>::difference(OrderedSetWrapper const& s) const -> OrderedSetWrapper { + OrderedSetWrapper ret; + for (auto const& e : *this) { + if (!s.contains(e)) + ret.add(e); + } + return ret; +} + +template <template <typename...> class Map, typename Value, typename Allocator, typename... Args> +std::ostream& operator<<(std::ostream& os, OrderedSetWrapper<Map, Value, Allocator, Args...> const& set) { + os << "("; + for (auto i = set.begin(); i != set.end(); ++i) { + if (i != set.begin()) + os << ", "; + os << *i; + } + os << ")"; + return os; +} + +} + +#endif diff --git a/source/core/StarParametricFunction.hpp b/source/core/StarParametricFunction.hpp new file mode 100644 index 0000000..5eaff14 --- /dev/null +++ b/source/core/StarParametricFunction.hpp @@ -0,0 +1,288 @@ +#ifndef STAR_PARAMETRIC_FUNCTION_HPP +#define STAR_PARAMETRIC_FUNCTION_HPP + +#include "StarInterpolation.hpp" + +namespace Star { + +// Describes a simple table from index to value, which operates on bins +// corresponding to ranges of indexes. IndexType can be any ordered type, +// ValueType can be anything. +template <typename IndexType, typename ValueType = IndexType> +class ParametricTable { +public: + typedef IndexType Index; + typedef ValueType Value; + + ParametricTable(); + + template <typename OtherIndexType, typename OtherValueType> + explicit ParametricTable(ParametricTable<OtherIndexType, OtherValueType> const& parametricFunction); + + // Construct a ParametricTable with a list of point pairs, which does not + // have to be sorted (it will be sorted internally). Throws an exception on + // duplicate index values. + template <typename PairContainer> + explicit ParametricTable(PairContainer indexValuePairs); + + // addPoint does not need to be called in order, it will insert the point in + // the correct ordered position for the given index, and return the position. + size_t addPoint(IndexType index, ValueType value); + void clearPoints(); + + size_t size() const; + bool empty() const; + + IndexType const& index(size_t i) const; + ValueType const& value(size_t i) const; + + // Returns true if the values of the table are also valid indexes (true when + // the data points are monotonic increasing) + bool isInvertible() const; + + // Invert the table, switching indexes and values. Throws an exception if + // the function is not invertible. Will not generally compile unless the + // Index and Value types are the same type. + void invert() const; + + // Find the value to the left of the given index. If the index is lower than + // the lowest index point, returns the first value. + ValueType const& get(IndexType index) const; + +protected: + typedef std::vector<IndexType> IndexList; + typedef std::vector<ValueType> ValueList; + + IndexList const& indexes() const; + ValueList const& values() const; + +private: + IndexList m_indexes; + ValueList m_values; +}; + +// Extension of ParametricTable that simplifies of all the complex +// interpolation code for interpolating an ordered list of points. Useful for +// describing a simple 2d or n-dimensional (using VectorN for value type) curve +// of one variable. IndexType should generally be float or double, and +// ValueType can be any type that can be interpolated. +template <typename IndexType, typename ValueType = IndexType> +class ParametricFunction : public ParametricTable<IndexType, ValueType> { +public: + typedef ParametricTable<IndexType, ValueType> Base; + + ParametricFunction( + InterpolationMode interpolationMode = InterpolationMode::Linear, BoundMode boundMode = BoundMode::Clamp); + + template <typename OtherIndexType, typename OtherValueType> + explicit ParametricFunction(ParametricFunction<OtherIndexType, OtherValueType> const& parametricFunction); + + template <typename PairContainer> + explicit ParametricFunction(PairContainer indexValuePairs, + InterpolationMode interpolationMode = InterpolationMode::Linear, + BoundMode boundMode = BoundMode::Clamp); + + InterpolationMode interpolationMode() const; + void setInterpolationMode(InterpolationMode interpolationType); + + BoundMode boundMode() const; + void setBoundMode(BoundMode boundMode); + + // Interpolates a value at the given index according to the interpolation and + // bound mode. + ValueType interpolate(IndexType index) const; + + // Synonym for interpolate + ValueType operator()(IndexType index) const; + +private: + InterpolationMode m_interpolationMode; + BoundMode m_boundMode; +}; + +template <typename IndexType, typename ValueType> +ParametricTable<IndexType, ValueType>::ParametricTable() {} + +template <typename IndexType, typename ValueType> +template <typename OtherIndexType, typename OtherValueType> +ParametricTable<IndexType, ValueType>::ParametricTable( + ParametricTable<OtherIndexType, OtherValueType> const& parametricTable) { + for (size_t i = 0; i < parametricTable.size(); ++i) { + m_indexes.push_back(parametricTable.index(i)); + m_values.push_back(parametricTable.value(i)); + } +} + +template <typename IndexType, typename ValueType> +template <typename PairContainer> +ParametricTable<IndexType, ValueType>::ParametricTable(PairContainer indexValuePairs) { + if (indexValuePairs.empty()) + return; + + sort(indexValuePairs, + [](typename PairContainer::value_type const& a, typename PairContainer::value_type const& b) { + return std::get<0>(a) < std::get<0>(b); + }); + + for (auto const pair : indexValuePairs) { + m_indexes.push_back(move(std::get<0>(pair))); + m_values.push_back(move(std::get<1>(pair))); + } + + for (size_t i = 0; i < size() - 1; ++i) { + if (m_indexes[i] == m_indexes[i + 1]) + throw MathException("Degenerate index values given in ParametricTable constructor"); + } +} + +template <typename IndexType, typename ValueType> +size_t ParametricTable<IndexType, ValueType>::addPoint(IndexType index, ValueType value) { + size_t insertLocation = std::distance(m_indexes.begin(), std::upper_bound(m_indexes.begin(), m_indexes.end(), index)); + m_indexes.insert(m_indexes.begin() + insertLocation, move(index)); + m_values.insert(m_values.begin() + insertLocation, move(value)); + return insertLocation; +} + +template <typename IndexType, typename ValueType> +void ParametricTable<IndexType, ValueType>::clearPoints() { + m_indexes.clear(); + m_values.clear(); +} + +template <typename IndexType, typename ValueType> +size_t ParametricTable<IndexType, ValueType>::size() const { + return m_indexes.size(); +} + +template <typename IndexType, typename ValueType> +bool ParametricTable<IndexType, ValueType>::empty() const { + return size() == 0; +} + +template <typename IndexType, typename ValueType> +IndexType const& ParametricTable<IndexType, ValueType>::index(size_t i) const { + return m_indexes.at(i); +} + +template <typename IndexType, typename ValueType> +ValueType const& ParametricTable<IndexType, ValueType>::value(size_t i) const { + return m_values.at(i); +} + +template <typename IndexType, typename ValueType> +bool ParametricTable<IndexType, ValueType>::isInvertible() const { + if (empty()) + return true; + + for (size_t i = 0; i < size() - 1; ++i) { + if (m_values[i] > m_values[i + 1]) + return false; + } + + return true; +} + +template <typename IndexType, typename ValueType> +void ParametricTable<IndexType, ValueType>::invert() const { + if (isInvertible()) + throw MathException("invert() called on non-invertible ParametricTable"); + + for (size_t i = 0; i < size(); ++i) + std::swap(m_indexes[i], m_values[i]); +} + +template <typename IndexType, typename ValueType> +ValueType const& ParametricTable<IndexType, ValueType>::get(IndexType index) const { + if (empty()) + throw MathException("get called on empty ParametricTable"); + + auto i = std::lower_bound(m_indexes.begin(), m_indexes.end(), index); + if (i != m_indexes.begin()) + --i; + + return m_values[std::distance(m_indexes.begin(), i)]; +} + +template <typename IndexType, typename ValueType> +auto ParametricTable<IndexType, ValueType>::indexes() const -> IndexList const & { + return m_indexes; +} + +template <typename IndexType, typename ValueType> +auto ParametricTable<IndexType, ValueType>::values() const -> ValueList const & { + return m_values; +} + +template <typename IndexType, typename ValueType> +ParametricFunction<IndexType, ValueType>::ParametricFunction(InterpolationMode interpolationMode, BoundMode boundMode) + : m_interpolationMode(interpolationMode), m_boundMode(boundMode) {} + +template <typename IndexType, typename ValueType> +template <typename OtherIndexType, typename OtherValueType> +ParametricFunction<IndexType, ValueType>::ParametricFunction( + ParametricFunction<OtherIndexType, OtherValueType> const& parametricFunction) + : Base(parametricFunction) { + m_interpolationMode = parametricFunction.interpolationMode(); + m_boundMode = parametricFunction.boundMode(); +} + +template <typename IndexType, typename ValueType> +template <typename PairContainer> +ParametricFunction<IndexType, ValueType>::ParametricFunction( + PairContainer indexValuePairs, InterpolationMode interpolationMode, BoundMode boundMode) + : Base(indexValuePairs) { + m_interpolationMode = interpolationMode; + m_boundMode = boundMode; +} + +template <typename IndexType, typename ValueType> +InterpolationMode ParametricFunction<IndexType, ValueType>::interpolationMode() const { + return m_interpolationMode; +} + +template <typename IndexType, typename ValueType> +void ParametricFunction<IndexType, ValueType>::setInterpolationMode(InterpolationMode interpolationType) { + m_interpolationMode = interpolationType; +} + +template <typename IndexType, typename ValueType> +BoundMode ParametricFunction<IndexType, ValueType>::boundMode() const { + return m_boundMode; +} + +template <typename IndexType, typename ValueType> +void ParametricFunction<IndexType, ValueType>::setBoundMode(BoundMode boundMode) { + m_boundMode = boundMode; +} + +template <typename IndexType, typename ValueType> +ValueType ParametricFunction<IndexType, ValueType>::interpolate(IndexType index) const { + if (Base::empty()) + return ValueType(); + + if (m_interpolationMode == InterpolationMode::HalfStep) { + return parametricInterpolate2(Base::indexes(), Base::values(), index, StepWeightOperator<IndexType>(), m_boundMode); + + } else if (m_interpolationMode == InterpolationMode::Linear) { + return parametricInterpolate2( + Base::indexes(), Base::values(), index, LinearWeightOperator<IndexType>(), m_boundMode); + + } else if (m_interpolationMode == InterpolationMode::Cubic) { + // ParametricFunction uses CubicWeights with linear extrapolation (not + // configurable atm) + return parametricInterpolate4( + Base::indexes(), Base::values(), index, Cubic4WeightOperator<IndexType>(true), m_boundMode); + + } else { + throw MathException("Unsupported interpolation type in ParametricFunction::interpolate"); + } +} + +template <typename IndexType, typename ValueType> +ValueType ParametricFunction<IndexType, ValueType>::operator()(IndexType index) const { + return interpolate(index); +} + +} + +#endif diff --git a/source/core/StarPeriodic.hpp b/source/core/StarPeriodic.hpp new file mode 100644 index 0000000..2cf24cc --- /dev/null +++ b/source/core/StarPeriodic.hpp @@ -0,0 +1,98 @@ +#ifndef STAR_PERIODIC_HPP +#define STAR_PERIODIC_HPP + +#include "StarMathCommon.hpp" +#include "StarRandom.hpp" + +namespace Star { + +// Perform some action every X ticks. Setting the tick count to 0 means never +// perform the action, 1 performs the action every call, 2 performs the action +// every other call, and so forth. +class Periodic { +public: + Periodic(unsigned everyXSteps = 1) + : m_counter(0), m_everyXSteps(everyXSteps) {} + + unsigned stepCount() const { + return m_everyXSteps; + } + + void setStepCount(unsigned everyXSteps) { + m_everyXSteps = everyXSteps; + if (everyXSteps != 0) + m_counter = clamp<unsigned>(m_counter, 0, m_everyXSteps - 1); + else + m_counter = 0; + } + + // Will the next tick() return true? + bool ready() const { + return m_everyXSteps != 0 && m_counter == 0; + } + + bool tick() { + if (m_everyXSteps == 0) + return false; + + if (m_counter == 0) { + m_counter = m_everyXSteps - 1; + return true; + } else { + --m_counter; + return false; + } + } + + template <typename Function> + void tick(Function&& function) { + if (tick()) + function(); + } + +private: + unsigned m_counter; + unsigned m_everyXSteps; +}; + +// Perform some action with a given period over an amount of some value (like +// time) with optional randomness. +class RatePeriodic { +public: + RatePeriodic(double period = 1, double noise = 0) + : m_period(period), m_noise(noise), m_counter(period + Random::randf(-noise, noise)), m_elapsed(0.0) {} + + double period() const { + return m_period; + } + + double noise() const { + return m_noise; + } + + template <typename Function> + void update(double amount, Function&& function) { + double subAmount = min(amount, m_counter); + m_counter -= subAmount; + amount -= subAmount; + m_elapsed += subAmount; + + if (m_counter <= 0.0) { + m_counter = m_period + Random::randf(-m_noise, m_noise); + function(m_elapsed); + m_elapsed = 0.0; + if (amount > 0.0) + update(amount, forward<Function>(function)); + } + } + +private: + double m_period; + double m_noise; + double m_counter; + double m_elapsed; +}; + +} + +#endif diff --git a/source/core/StarPeriodicFunction.hpp b/source/core/StarPeriodicFunction.hpp new file mode 100644 index 0000000..ee74808 --- /dev/null +++ b/source/core/StarPeriodicFunction.hpp @@ -0,0 +1,82 @@ +#ifndef STAR_PERIODIC_FUNCTION_HPP +#define STAR_PERIODIC_FUNCTION_HPP + +#include "StarInterpolation.hpp" +#include "StarRandom.hpp" + +namespace Star { + +// Repeating, periodic function with optional period and magnitude variance. +// Each cycle of the function will randomize the min and max values of the +// function by the magnitude variance, and the period by the period variance. +// Can approximate a randomized sin wave, triangle wave, square wave, etc based +// on the weight operator provided to the value method. +template <typename Float> +class PeriodicFunction { +public: + PeriodicFunction( + Float period = 1, Float min = 0, Float max = 1, Float periodVariance = 0, Float magnitudeVariance = 0); + + void update(Float delta); + + template <typename WeightOperator> + Float value(WeightOperator weightOperator) const; + +private: + Float m_halfPeriod; + Float m_min; + Float m_max; + Float m_halfPeriodVariance; + Float m_magnitudeVariance; + + Float m_timerMax; + Float m_timer; + Float m_source; + Float m_target; + bool m_targetMode; +}; + +template <typename Float> +PeriodicFunction<Float>::PeriodicFunction( + Float period, Float min, Float max, Float periodVariance, Float magnitudeVariance) { + m_halfPeriod = period / 2; + m_min = min; + m_max = max; + m_halfPeriodVariance = periodVariance / 2; + m_magnitudeVariance = magnitudeVariance; + + m_timerMax = m_halfPeriod; + m_timer = 0; + m_source = m_max + Random::randf(-1, 1) * m_magnitudeVariance; + m_target = m_min + Random::randf(-1, 1) * m_magnitudeVariance; + m_targetMode = true; +} + +template <typename Float> +void PeriodicFunction<Float>::update(Float delta) { + m_timer -= delta; + + // Only bring the timer forward once, rather than doing this in a loop. This + // makes the function behave somewhat differently than it would for deltas + // which are greater than the period, but it avoids infinite looping + if (m_timer <= 0.0f) { + m_source = m_target; + m_target = (m_targetMode ? m_max : m_min) + Random::randf(-1, 1) * m_magnitudeVariance; + m_targetMode = !m_targetMode; + m_timerMax = m_halfPeriod + Random::randf(-1, 1) * m_halfPeriodVariance; + m_timer = max(0.0f, m_timer + m_timerMax); + } +} + +template <typename Float> +template <typename WeightOperator> +Float PeriodicFunction<Float>::value(WeightOperator weightOperator) const { + // This is inverted, m_timer goes from m_timerMax to 0 as the value should go + // from m_source to m_target + auto wvec = weightOperator(m_timer / m_timerMax); + return m_target * wvec[0] + m_source * wvec[1]; +} + +} + +#endif diff --git a/source/core/StarPerlin.cpp b/source/core/StarPerlin.cpp new file mode 100644 index 0000000..6688a7a --- /dev/null +++ b/source/core/StarPerlin.cpp @@ -0,0 +1,12 @@ +#include "StarPerlin.hpp" + +namespace Star { + +EnumMap<PerlinType> const PerlinTypeNames{ + {PerlinType::Uninitialized, "uninitialized"}, + {PerlinType::Perlin, "perlin"}, + {PerlinType::Billow, "billow"}, + {PerlinType::RidgedMulti, "ridgedMulti"}, +}; + +} diff --git a/source/core/StarPerlin.hpp b/source/core/StarPerlin.hpp new file mode 100644 index 0000000..b042e47 --- /dev/null +++ b/source/core/StarPerlin.hpp @@ -0,0 +1,718 @@ +#ifndef STAR_PERLIN_HPP +#define STAR_PERLIN_HPP + +#include "StarJson.hpp" +#include "StarBiMap.hpp" +#include "StarInterpolation.hpp" +#include "StarRandom.hpp" + +namespace Star { + +STAR_EXCEPTION(PerlinException, StarException); + +enum class PerlinType { + Uninitialized, + Perlin, + Billow, + RidgedMulti +}; +extern EnumMap<PerlinType> const PerlinTypeNames; + +int const PerlinSampleSize = 512; + +template <typename Float> +class Perlin { +public: + // Default constructed perlin noise is uninitialized and cannot be queried. + Perlin(); + + Perlin(unsigned octaves, Float freq, Float amp, Float bias, Float alpha, Float beta, uint64_t seed); + Perlin(PerlinType type, unsigned octaves, Float freq, Float amp, Float bias, Float alpha, Float beta, uint64_t seed); + Perlin(Json const& config, uint64_t seed); + explicit Perlin(Json const& json); + + Perlin(Perlin const& perlin); + Perlin(Perlin&& perlin); + + Perlin& operator=(Perlin const& perlin); + Perlin& operator=(Perlin&& perlin); + + Float get(Float x) const; + Float get(Float x, Float y) const; + Float get(Float x, Float y, Float z) const; + + PerlinType type() const; + + unsigned octaves() const; + Float frequency() const; + Float amplitude() const; + Float bias() const; + Float alpha() const; + Float beta() const; + + Json toJson() const; + +private: + static Float s_curve(Float t); + static void setup(Float v, int& b0, int& b1, Float& r0, Float& r1); + + static Float at2(Float* q, Float rx, Float ry); + static Float at3(Float* q, Float rx, Float ry, Float rz); + + Float noise1(Float arg) const; + Float noise2(Float vec[2]) const; + Float noise3(Float vec[3]) const; + + void normalize2(Float v[2]) const; + void normalize3(Float v[3]) const; + + void init(uint64_t seed); + + Float perlin(Float x) const; + Float perlin(Float x, Float y) const; + Float perlin(Float x, Float y, Float z) const; + + Float ridgedMulti(Float x) const; + Float ridgedMulti(Float x, Float y) const; + Float ridgedMulti(Float x, Float y, Float z) const; + + Float billow(Float x) const; + Float billow(Float x, Float y) const; + Float billow(Float x, Float y, Float z) const; + + PerlinType m_type; + uint64_t m_seed; + + int m_octaves; + Float m_frequency; + Float m_amplitude; + Float m_bias; + Float m_alpha; + Float m_beta; + + // Only used for RidgedMulti + Float m_offset; + Float m_gain; + + unique_ptr<int[]> p; + unique_ptr<Float[][3]> g3; + unique_ptr<Float[][2]> g2; + unique_ptr<Float[]> g1; +}; + +typedef Perlin<float> PerlinF; +typedef Perlin<double> PerlinD; + +template <typename Float> +Float Perlin<Float>::s_curve(Float t) { + return t * t * (3.0 - 2.0 * t); +} + +template <typename Float> +void Perlin<Float>::setup(Float v, int& b0, int& b1, Float& r0, Float& r1) { + int iv = floor(v); + Float fv = v - iv; + + b0 = iv & (PerlinSampleSize - 1); + b1 = (iv + 1) & (PerlinSampleSize - 1); + r0 = fv; + r1 = fv - 1.0; +} + +template <typename Float> +Float Perlin<Float>::at2(Float* q, Float rx, Float ry) { + return rx * q[0] + ry * q[1]; +} + +template <typename Float> +Float Perlin<Float>::at3(Float* q, Float rx, Float ry, Float rz) { + return rx * q[0] + ry * q[1] + rz * q[2]; +} + +template <typename Float> +Perlin<Float>::Perlin() { + m_type = PerlinType::Uninitialized; + m_alpha = 0; + m_amplitude = 0; + m_frequency = 0; + m_seed = 0; + m_gain = 0; + m_beta = 0; + m_offset = 0; + m_bias = 0; + m_octaves = 0; +} + +template <typename Float> +Perlin<Float>::Perlin(unsigned octaves, Float freq, Float amp, Float bias, Float alpha, Float beta, uint64_t seed) { + m_type = PerlinType::Perlin; + m_seed = seed; + + m_octaves = octaves; + m_frequency = freq; + m_amplitude = amp; + m_bias = bias; + m_alpha = alpha; + m_beta = beta; + + // TODO: These ought to be configurable + m_offset = 1.0; + m_gain = 2.0; + + init(m_seed); +} + +template <typename Float> +Perlin<Float>::Perlin(PerlinType type, unsigned octaves, Float freq, Float amp, Float bias, Float alpha, Float beta, uint64_t seed) { + m_type = type; + m_seed = seed; + + m_octaves = octaves; + m_frequency = freq; + m_amplitude = amp; + m_bias = bias; + m_alpha = alpha; + m_beta = beta; + + // TODO: These ought to be configurable + m_offset = 1.0; + m_gain = 2.0; + + init(m_seed); +} + +template <typename Float> +Perlin<Float>::Perlin(Json const& config, uint64_t seed) + : Perlin(config.set("seed", seed)) {} + +template <typename Float> +Perlin<Float>::Perlin(Json const& json) { + m_seed = json.getUInt("seed"); + m_octaves = json.getInt("octaves", 1); + m_frequency = json.getDouble("frequency", 1.0); + m_amplitude = json.getDouble("amplitude", 1.0); + m_bias = json.getDouble("bias", 0.0); + m_alpha = json.getDouble("alpha", 2.0); + m_beta = json.getDouble("beta", 2.0); + + m_offset = json.getDouble("offset", 1.0); + m_gain = json.getDouble("gain", 2.0); + + m_type = PerlinTypeNames.getLeft(json.getString("type")); + + init(m_seed); +} + +template <typename Float> +Perlin<Float>::Perlin(Perlin const& perlin) { + *this = perlin; +} + +template <typename Float> +Perlin<Float>::Perlin(Perlin&& perlin) { + *this = move(perlin); +} + +template <typename Float> +Perlin<Float>& Perlin<Float>::operator=(Perlin const& perlin) { + if (perlin.m_type == PerlinType::Uninitialized) { + m_type = PerlinType::Uninitialized; + p.reset(); + g3.reset(); + g2.reset(); + g1.reset(); + + } else if (this != &perlin) { + m_type = perlin.m_type; + m_seed = perlin.m_seed; + m_octaves = perlin.m_octaves; + m_frequency = perlin.m_frequency; + m_amplitude = perlin.m_amplitude; + m_bias = perlin.m_bias; + m_alpha = perlin.m_alpha; + m_beta = perlin.m_beta; + m_offset = perlin.m_offset; + m_gain = perlin.m_gain; + + p.reset(new int[PerlinSampleSize + PerlinSampleSize + 2]); + g3.reset(new Float[PerlinSampleSize + PerlinSampleSize + 2][3]); + g2.reset(new Float[PerlinSampleSize + PerlinSampleSize + 2][2]); + g1.reset(new Float[PerlinSampleSize + PerlinSampleSize + 2]); + + std::memcpy(p.get(), perlin.p.get(), (PerlinSampleSize + PerlinSampleSize + 2) * sizeof(int)); + std::memcpy(g3.get(), perlin.g3.get(), (PerlinSampleSize + PerlinSampleSize + 2) * sizeof(Float) * 3); + std::memcpy(g2.get(), perlin.g2.get(), (PerlinSampleSize + PerlinSampleSize + 2) * sizeof(Float) * 2); + std::memcpy(g1.get(), perlin.g1.get(), (PerlinSampleSize + PerlinSampleSize + 2) * sizeof(Float)); + } + + return *this; +} + +template <typename Float> +Perlin<Float>& Perlin<Float>::operator=(Perlin&& perlin) { + m_type = perlin.m_type; + m_seed = perlin.m_seed; + m_octaves = perlin.m_octaves; + m_frequency = perlin.m_frequency; + m_amplitude = perlin.m_amplitude; + m_bias = perlin.m_bias; + m_alpha = perlin.m_alpha; + m_beta = perlin.m_beta; + m_offset = perlin.m_offset; + m_gain = perlin.m_gain; + + p = move(perlin.p); + g3 = move(perlin.g3); + g2 = move(perlin.g2); + g1 = move(perlin.g1); + + return *this; +} + +template <typename Float> +Float Perlin<Float>::get(Float x) const { + switch (m_type) { + case PerlinType::Perlin: + return perlin(x); + case PerlinType::Billow: + return billow(x); + case PerlinType::RidgedMulti: + return ridgedMulti(x); + default: + throw PerlinException("::get called on uninitialized Perlin"); + } +} + +template <typename Float> +Float Perlin<Float>::get(Float x, Float y) const { + switch (m_type) { + case PerlinType::Perlin: + return perlin(x, y); + case PerlinType::Billow: + return billow(x, y); + case PerlinType::RidgedMulti: + return ridgedMulti(x, y); + default: + throw PerlinException("::get called on uninitialized Perlin"); + } +} + +template <typename Float> +Float Perlin<Float>::get(Float x, Float y, Float z) const { + switch (m_type) { + case PerlinType::Perlin: + return perlin(x, y, z); + case PerlinType::Billow: + return billow(x, y, z); + case PerlinType::RidgedMulti: + return ridgedMulti(x, y, z); + default: + throw PerlinException("::get called on uninitialized Perlin"); + } +} + +template <typename Float> +PerlinType Perlin<Float>::type() const { + return m_type; +} + +template <typename Float> +unsigned Perlin<Float>::octaves() const { + return m_octaves; +} + +template <typename Float> +Float Perlin<Float>::frequency() const { + return m_frequency; +} + +template <typename Float> +Float Perlin<Float>::amplitude() const { + return m_amplitude; +} + +template <typename Float> +Float Perlin<Float>::bias() const { + return m_bias; +} + +template <typename Float> +Float Perlin<Float>::alpha() const { + return m_alpha; +} + +template <typename Float> +Float Perlin<Float>::beta() const { + return m_beta; +} + +template <typename Float> +Json Perlin<Float>::toJson() const { + return JsonObject{ + {"seed", m_seed}, + {"octaves", m_octaves}, + {"frequency", m_frequency}, + {"amplitude", m_amplitude}, + {"bias", m_bias}, + {"alpha", m_alpha}, + {"beta", m_beta}, + {"offset", m_offset}, + {"gain", m_gain}, + {"type", PerlinTypeNames.getRight(m_type)} + }; +} + +template <typename Float> +inline Float Perlin<Float>::noise1(Float arg) const { + int bx0, bx1; + Float rx0, rx1, sx, u, v; + + setup(arg, bx0, bx1, rx0, rx1); + + sx = s_curve(rx0); + u = rx0 * g1[p[bx0]]; + v = rx1 * g1[p[bx1]]; + + return (lerp(sx, u, v)); +} + +template <typename Float> +inline Float Perlin<Float>::noise2(Float vec[2]) const { + int bx0, bx1, by0, by1, b00, b10, b01, b11; + Float rx0, rx1, ry0, ry1, sx, sy, a, b, u, v; + int i, j; + + setup(vec[0], bx0, bx1, rx0, rx1); + setup(vec[1], by0, by1, ry0, ry1); + + i = p[bx0]; + j = p[bx1]; + + b00 = p[i + by0]; + b10 = p[j + by0]; + b01 = p[i + by1]; + b11 = p[j + by1]; + + sx = s_curve(rx0); + sy = s_curve(ry0); + + u = at2(g2[b00], rx0, ry0); + v = at2(g2[b10], rx1, ry0); + a = lerp(sx, u, v); + + u = at2(g2[b01], rx0, ry1); + v = at2(g2[b11], rx1, ry1); + b = lerp(sx, u, v); + + return lerp(sy, a, b); +} + +template <typename Float> +inline Float Perlin<Float>::noise3(Float vec[3]) const { + int bx0, bx1, by0, by1, bz0, bz1, b00, b10, b01, b11; + Float rx0, rx1, ry0, ry1, rz0, rz1, sx, sy, sz, a, b, c, d, u, v; + int i, j; + + setup(vec[0], bx0, bx1, rx0, rx1); + setup(vec[1], by0, by1, ry0, ry1); + setup(vec[2], bz0, bz1, rz0, rz1); + + i = p[bx0]; + j = p[bx1]; + + b00 = p[i + by0]; + b10 = p[j + by0]; + b01 = p[i + by1]; + b11 = p[j + by1]; + + sx = s_curve(rx0); + sy = s_curve(ry0); + sz = s_curve(rz0); + + u = at3(g3[b00 + bz0], rx0, ry0, rz0); + v = at3(g3[b10 + bz0], rx1, ry0, rz0); + a = lerp(sx, u, v); + + u = at3(g3[b01 + bz0], rx0, ry1, rz0); + v = at3(g3[b11 + bz0], rx1, ry1, rz0); + b = lerp(sx, u, v); + + c = lerp(sy, a, b); + + u = at3(g3[b00 + bz1], rx0, ry0, rz1); + v = at3(g3[b10 + bz1], rx1, ry0, rz1); + a = lerp(sx, u, v); + + u = at3(g3[b01 + bz1], rx0, ry1, rz1); + v = at3(g3[b11 + bz1], rx1, ry1, rz1); + b = lerp(sx, u, v); + + d = lerp(sy, a, b); + + return lerp(sz, c, d); +} + +template <typename Float> +void Perlin<Float>::normalize2(Float v[2]) const { + Float s; + + s = sqrt(v[0] * v[0] + v[1] * v[1]); + if (s == 0.0f) { + v[0] = 1.0f; + v[1] = 0.0f; + } else { + v[0] = v[0] / s; + v[1] = v[1] / s; + } +} + +template <typename Float> +void Perlin<Float>::normalize3(Float v[3]) const { + Float s; + + s = sqrt(v[0] * v[0] + v[1] * v[1] + v[2] * v[2]); + if (s == 0.0f) { + v[0] = 1.0f; + v[1] = 0.0f; + v[2] = 0.0f; + } else { + v[0] = v[0] / s; + v[1] = v[1] / s; + v[2] = v[2] / s; + } +} + +template <typename Float> +void Perlin<Float>::init(uint64_t seed) { + RandomSource randomSource(seed); + + p.reset(new int[PerlinSampleSize + PerlinSampleSize + 2]); + g3.reset(new Float[PerlinSampleSize + PerlinSampleSize + 2][3]); + g2.reset(new Float[PerlinSampleSize + PerlinSampleSize + 2][2]); + g1.reset(new Float[PerlinSampleSize + PerlinSampleSize + 2]); + + int i, j, k; + + for (i = 0; i < PerlinSampleSize; i++) { + p[i] = i; + g1[i] = (Float)(randomSource.randInt(-PerlinSampleSize, PerlinSampleSize)) / PerlinSampleSize; + + for (j = 0; j < 2; j++) + g2[i][j] = (Float)(randomSource.randInt(-PerlinSampleSize, PerlinSampleSize)) / PerlinSampleSize; + normalize2(g2[i]); + + for (j = 0; j < 3; j++) + g3[i][j] = (Float)(randomSource.randInt(-PerlinSampleSize, PerlinSampleSize)) / PerlinSampleSize; + normalize3(g3[i]); + } + + while (--i) { + k = p[i]; + p[i] = p[j = randomSource.randUInt(PerlinSampleSize - 1)]; + p[j] = k; + } + + for (i = 0; i < PerlinSampleSize + 2; i++) { + p[PerlinSampleSize + i] = p[i]; + g1[PerlinSampleSize + i] = g1[i]; + for (j = 0; j < 2; j++) + g2[PerlinSampleSize + i][j] = g2[i][j]; + for (j = 0; j < 3; j++) + g3[PerlinSampleSize + i][j] = g3[i][j]; + } +} + +template <typename Float> +inline Float Perlin<Float>::perlin(Float x) const { + int i; + Float val, sum = 0; + Float p, scale = 1; + + p = x * m_frequency; + for (i = 0; i < m_octaves; i++) { + val = noise1(p); + sum += val / scale; + scale *= m_alpha; + p *= m_beta; + } + return sum * m_amplitude + m_bias; +} + +template <typename Float> +inline Float Perlin<Float>::perlin(Float x, Float y) const { + int i; + Float val, sum = 0; + Float p[2], scale = 1; + + p[0] = x * m_frequency; + p[1] = y * m_frequency; + for (i = 0; i < m_octaves; i++) { + val = noise2(p); + sum += val / scale; + scale *= m_alpha; + p[0] *= m_beta; + p[1] *= m_beta; + } + return sum * m_amplitude + m_bias; +} + +template <typename Float> +inline Float Perlin<Float>::perlin(Float x, Float y, Float z) const { + int i; + Float val, sum = 0; + Float p[3], scale = 1; + + p[0] = x * m_frequency; + p[1] = y * m_frequency; + p[2] = z * m_frequency; + for (i = 0; i < m_octaves; i++) { + val = noise3(p); + sum += val / scale; + scale *= m_alpha; + p[0] *= m_beta; + p[1] *= m_beta; + p[2] *= m_beta; + } + + return sum * m_amplitude + m_bias; +} + +template <typename Float> +inline Float Perlin<Float>::ridgedMulti(Float x) const { + Float val, sum = 0; + Float scale = 1; + Float weight = 1.0; + + x *= m_frequency; + for (int i = 0; i < m_octaves; ++i) { + val = noise1(x); + + val = m_offset - fabs(val); + val *= val; + val *= weight; + + weight = clamp<Float>(val * m_gain, 0.0, 1.0); + + sum += val / scale; + scale *= m_alpha; + x *= m_beta; + } + + return ((sum * 1.25) - 1.0) * m_amplitude + m_bias; +} + +template <typename Float> +inline Float Perlin<Float>::ridgedMulti(Float x, Float y) const { + Float val, sum = 0; + Float p[2], scale = 1; + Float weight = 1.0; + + p[0] = x * m_frequency; + p[1] = y * m_frequency; + for (int i = 0; i < m_octaves; ++i) { + val = noise2(p); + + val = m_offset - fabs(val); + val *= val; + val *= weight; + + weight = clamp<Float>(val * m_gain, 0.0, 1.0); + + sum += val / scale; + scale *= m_alpha; + p[0] *= m_beta; + p[1] *= m_beta; + } + + return ((sum * 1.25) - 1.0) * m_amplitude + m_bias; +} + +template <typename Float> +inline Float Perlin<Float>::ridgedMulti(Float x, Float y, Float z) const { + Float val, sum = 0; + Float p[3], scale = 1; + Float weight = 1.0; + + p[0] = x * m_frequency; + p[1] = y * m_frequency; + p[2] = z * m_frequency; + for (int i = 0; i < m_octaves; ++i) { + val = noise3(p); + + val = m_offset - fabs(val); + val *= val; + val *= weight; + + weight = clamp<Float>(val * m_gain, 0.0, 1.0); + + sum += val / scale; + scale *= m_alpha; + p[0] *= m_beta; + p[1] *= m_beta; + p[2] *= m_beta; + } + + return ((sum * 1.25) - 1.0) * m_amplitude + m_bias; +} + +template <typename Float> +inline Float Perlin<Float>::billow(Float x) const { + Float val, sum = 0; + Float p, scale = 1; + + p = x * m_frequency; + for (int i = 0; i < m_octaves; i++) { + val = noise1(p); + val = 2.0 * fabs(val) - 1.0; + + sum += val / scale; + scale *= m_alpha; + p *= m_beta; + } + return (sum + 0.5) * m_amplitude + m_bias; +} + +template <typename Float> +inline Float Perlin<Float>::billow(Float x, Float y) const { + Float val, sum = 0; + Float p[2], scale = 1; + + p[0] = x * m_frequency; + p[1] = y * m_frequency; + for (int i = 0; i < m_octaves; i++) { + val = noise2(p); + val = 2.0 * fabs(val) - 1.0; + + sum += val / scale; + scale *= m_alpha; + p[0] *= m_beta; + p[1] *= m_beta; + } + return (sum + 0.5) * m_amplitude + m_bias; +} + +template <typename Float> +inline Float Perlin<Float>::billow(Float x, Float y, Float z) const { + Float val, sum = 0; + Float p[3], scale = 1; + + p[0] = x * m_frequency; + p[1] = y * m_frequency; + p[2] = z * m_frequency; + for (int i = 0; i < m_octaves; i++) { + val = noise3(p); + val = 2.0 * fabs(val) - 1.0; + + sum += val / scale; + scale *= m_alpha; + p[0] *= m_beta; + p[1] *= m_beta; + p[2] *= m_beta; + } + + return (sum + 0.5) * m_amplitude + m_bias; +} + +} + +#endif diff --git a/source/core/StarPoly.hpp b/source/core/StarPoly.hpp new file mode 100644 index 0000000..58cfca8 --- /dev/null +++ b/source/core/StarPoly.hpp @@ -0,0 +1,750 @@ +#ifndef STAR_POLY_HPP +#define STAR_POLY_HPP + +#include <numeric> + +#include "StarRect.hpp" + +namespace Star { + +template <typename DataType> +class Polygon { +public: + typedef Vector<DataType, 2> Vertex; + typedef Star::Line<DataType, 2> Line; + typedef Star::Box<DataType, 2> Rect; + + struct IntersectResult { + // Whether or not the two objects intersect + bool intersects; + // How much *this* poly must be moved in order to make them not intersect + // anymore + Vertex overlap; + }; + + struct LineIntersectResult { + // Point of intersection + Vertex point; + // t value at the point of intersection of the line that was checked + DataType along; + // Side that the line first intersected, if the line starts inside the + // polygon, this will not be set. + Maybe<size_t> intersectedSide; + }; + + typedef List<Vertex> VertexList; + typedef typename VertexList::iterator iterator; + typedef typename VertexList::const_iterator const_iterator; + + static Polygon convexHull(VertexList points); + static Polygon clip(Polygon inputPoly, Polygon convexClipPoly); + + // Creates a null polygon + Polygon(); + Polygon(Polygon const& rhs); + Polygon(Polygon&& rhs); + + template <typename DataType2> + explicit Polygon(Box<DataType2, 2> const& rect); + + template <typename DataType2> + explicit Polygon(Polygon<DataType2> const& p2); + + // This seems weird, but it isn't. SAT intersection works perfectly well + // with one Poly having only a single vertex. + explicit Polygon(Vertex const& coord); + + // When specifying a polygon using this constructor the list should be in + // counterclockwise order. + explicit Polygon(VertexList const& vertexes); + + Polygon(std::initializer_list<Vertex> vertexes); + + bool isNull() const; + + bool isConvex() const; + float convexArea() const; + + void deduplicateVertexes(float maxDistance); + + void add(Vertex const& a); + void remove(size_t i); + + void clear(); + + VertexList const& vertexes() const; + VertexList& vertexes(); + + size_t sides() const; + + Line side(size_t i) const; + + DataType distance(Vertex const& c) const; + + void translate(Vertex const& c); + + void setCenter(Vertex const& c); + + void rotate(DataType a, Vertex const& c = Vertex()); + + void scale(Vertex const& s, Vertex const& c = Vertex()); + void scale(DataType s, Vertex const& c = Vertex()); + + void flipHorizontal(DataType horizontalPos = DataType()); + void flipVertical(DataType verticalPos = DataType()); + + template <typename DataType2> + void transform(Matrix3<DataType2> const& transMat); + + Vertex const& operator[](size_t i) const; + Vertex& operator[](size_t i); + + bool operator==(Polygon const& rhs) const; + + Polygon& operator=(Polygon const& rhs); + Polygon& operator=(Polygon&& rhs); + + iterator begin(); + const_iterator begin() const; + + iterator end(); + const_iterator end() const; + + // vertex and normal wrap around so that i can never be out of range. + Vertex const& vertex(size_t i) const; + Vertex normal(size_t i) const; + + Vertex center() const; + + // a point in the volume, within min and max y, moved downwards to be a half + // width from the bottom (if that point is within a half width from the + // top, center() is returned) + Vertex bottomCenter() const; + + Rect boundBox() const; + + // Determine winding number of the given point. + int windingNumber(Vertex const& p) const; + + bool contains(Vertex const& p) const; + + // Normal SAT intersection finding the shortest separation of two convex + // polys. + IntersectResult satIntersection(Polygon const& p) const; + + // A directional version of a SAT intersection that will only separate + // parallel to the given direction. If choseSign is true, then the + // separation can occur either with the given direction or opposite it, but + // still parallel. If it is false, separation will always occur in the given + // direction only. + IntersectResult directionalSatIntersection(Polygon const& p, Vertex const& direction, bool chooseSign) const; + + // Returns the closest intersection with the poly, if any. + Maybe<LineIntersectResult> lineIntersection(Line const& l) const; + + bool intersects(Polygon const& p) const; + bool intersects(Line const& l) const; + +private: + // i must be between 0 and m_vertexes.size() - 1 + Line sideAt(size_t i) const; + + VertexList m_vertexes; +}; + +template <typename DataType> +std::ostream& operator<<(std::ostream& os, Polygon<DataType> const& poly); + +typedef Polygon<int> PolyI; +typedef Polygon<float> PolyF; +typedef Polygon<double> PolyD; + +template <typename DataType> +Polygon<DataType> Polygon<DataType>::convexHull(VertexList points) { + if (points.empty()) + return {}; + + auto cross = [](Vertex o, Vertex a, Vertex b) { + return (a[0] - o[0]) * (b[1] - o[1]) - (a[1] - o[1]) * (b[0] - o[0]); + }; + sort(points); + + VertexList lower; + for (auto const& point : points) { + while (lower.size() >= 2 && cross(lower[lower.size() - 2], lower[lower.size() - 1], point) <= 0) + lower.removeLast(); + lower.append(point); + } + + VertexList upper; + for (auto const& point : reverseIterate(points)) { + while (upper.size() >= 2 && cross(upper[upper.size() - 2], upper[upper.size() - 1], point) <= 0) + upper.removeLast(); + upper.append(point); + } + + upper.removeLast(); + lower.removeLast(); + lower.appendAll(take(upper)); + return Polygon<DataType>(move(lower)); +} + +template <typename DataType> +Polygon<DataType> Polygon<DataType>::clip(Polygon inputPoly, Polygon convexClipPoly) { + if (inputPoly.sides() == 0) + return inputPoly; + + auto insideEdge = [](Line const& edge, Vertex const& p) { + return ((edge.max() - edge.min()) ^ (p - edge.min())) > 0; + }; + + VertexList outputVertexes = take(inputPoly.m_vertexes); + for (size_t i = 0; i < convexClipPoly.sides(); ++i) { + if (outputVertexes.empty()) + break; + + Line clipEdge = convexClipPoly.sideAt(i); + VertexList inputVertexes = take(outputVertexes); + Vertex s = inputVertexes.last(); + for (Vertex e : inputVertexes) { + if (insideEdge(clipEdge, e)) { + if (!insideEdge(clipEdge, s)) + outputVertexes.append(clipEdge.intersection(Line(s, e)).point); + outputVertexes.append(e); + } else if (insideEdge(clipEdge, s)) { + outputVertexes.append(clipEdge.intersection(Line(s, e)).point); + } + s = e; + } + } + + return Polygon(move(outputVertexes)); +} + +template <typename DataType> +Polygon<DataType>::Polygon() {} + +template <typename DataType> +Polygon<DataType>::Polygon(Polygon const& rhs) + : m_vertexes(rhs.m_vertexes) {} + +template <typename DataType> +Polygon<DataType>::Polygon(Polygon&& rhs) + : m_vertexes(move(rhs.m_vertexes)) {} + +template <typename DataType> +template <typename DataType2> +Polygon<DataType>::Polygon(Box<DataType2, 2> const& rect) { + m_vertexes = { + Vertex(rect.min()), Vertex(rect.max()[0], rect.min()[1]), Vertex(rect.max()), Vertex(rect.min()[0], rect.max()[1])}; +} + +template <typename DataType> +template <typename DataType2> +Polygon<DataType>::Polygon(Polygon<DataType2> const& p2) { + for (auto const& v : p2) + m_vertexes.push_back(Vertex(v)); +} + +template <typename DataType> +Polygon<DataType>::Polygon(Vertex const& coord) { + m_vertexes.push_back(coord); +} + +template <typename DataType> +Polygon<DataType>::Polygon(VertexList const& vertexes) + : m_vertexes(vertexes) {} + +template <typename DataType> +Polygon<DataType>::Polygon(std::initializer_list<Vertex> vertexes) + : m_vertexes(vertexes) {} + +template <typename DataType> +bool Polygon<DataType>::isNull() const { + return m_vertexes.empty(); +} + +template <typename DataType> +bool Polygon<DataType>::isConvex() const { + if (sides() < 2) + return true; + + for (unsigned i = 0; i < sides(); ++i) { + if ((side(i + 1).diff() ^ side(i).diff()) > 0) + return false; + } + + return true; +} + +template <typename DataType> +float Polygon<DataType>::convexArea() const { + float area = 0.0f; + for (size_t i = 0; i < m_vertexes.size(); ++i) { + Vertex const& v1 = m_vertexes[i]; + Vertex const& v2 = i == m_vertexes.size() - 1 ? m_vertexes[0] : m_vertexes[i + 1]; + area += 0.5f * (v1[0] * v2[1] - v1[1] * v2[0]); + } + return area; +} + +template <typename DataType> +void Polygon<DataType>::deduplicateVertexes(float maxDistance) { + if (m_vertexes.empty()) + return; + + float distSquared = square(maxDistance); + VertexList newVertexes = {m_vertexes[0]}; + for (size_t i = 1; i < m_vertexes.size(); ++i) { + if (vmagSquared(m_vertexes[i] - newVertexes.last()) > distSquared) + newVertexes.append(m_vertexes[i]); + } + + if (vmagSquared(newVertexes.first() - newVertexes.last()) <= distSquared) + newVertexes.removeLast(); + + m_vertexes = move(newVertexes); +} + +template <typename DataType> +void Polygon<DataType>::add(Vertex const& a) { + m_vertexes.push_back(a); +} + +template <typename DataType> +void Polygon<DataType>::remove(size_t i) { + auto it = begin() + i % sides(); + m_vertexes.erase(it); +} + +template <typename DataType> +void Polygon<DataType>::clear() { + m_vertexes.clear(); +} + +template <typename DataType> +typename Polygon<DataType>::VertexList const& Polygon<DataType>::vertexes() const { + return m_vertexes; +} + +template <typename DataType> +typename Polygon<DataType>::VertexList& Polygon<DataType>::vertexes() { + return m_vertexes; +} + +template <typename DataType> +size_t Polygon<DataType>::sides() const { + return m_vertexes.size(); +} + +template <typename DataType> +typename Polygon<DataType>::Line Polygon<DataType>::side(size_t i) const { + return sideAt(i % m_vertexes.size()); +} + +template <typename DataType> +DataType Polygon<DataType>::distance(Vertex const& c) const { + if (contains(c)) + return 0; + + DataType dist = highest<DataType>(); + for (size_t i = 0; i < m_vertexes.size(); ++i) + dist = min(dist, sideAt(i).distanceTo(c)); + + return dist; +} + +template <typename DataType> +void Polygon<DataType>::translate(Vertex const& c) { + for (auto& v : m_vertexes) + v += c; +} + +template <typename DataType> +void Polygon<DataType>::setCenter(Vertex const& c) { + translate(c - center()); +} + +template <typename DataType> +void Polygon<DataType>::rotate(DataType a, Vertex const& c) { + for (auto& v : m_vertexes) + v = (v - c).rotate(a) + c; +} + +template <typename DataType> +void Polygon<DataType>::scale(Vertex const& s, Vertex const& c) { + for (auto& v : m_vertexes) + v = vmult((v - c), s) + c; +} + +template <typename DataType> +void Polygon<DataType>::scale(DataType s, Vertex const& c) { + scale(Vertex::filled(s), c); +} + +template <typename DataType> +void Polygon<DataType>::flipHorizontal(DataType horizontalPos) { + scale(Vertex(-1, 1), Vertex(horizontalPos, 0)); + // Reverse vertexes to make sure poly remains counter-clockwise after flip. + std::reverse(m_vertexes.begin(), m_vertexes.end()); +} + +template <typename DataType> +void Polygon<DataType>::flipVertical(DataType verticalPos) { + scale(Vertex(1, -1), Vertex(0, verticalPos)); + // Reverse vertexes to make sure poly remains counter-clockwise after flip. + std::reverse(m_vertexes.begin(), m_vertexes.end()); +} + +template <typename DataType> +template <typename DataType2> +void Polygon<DataType>::transform(Matrix3<DataType2> const& transMat) { + for (auto& v : m_vertexes) + v = transMat.transformVec2(v); +} + +template <typename DataType> +typename Polygon<DataType>::Vertex const& Polygon<DataType>::operator[](size_t i) const { + return m_vertexes[i]; +} + +template <typename DataType> +typename Polygon<DataType>::Vertex& Polygon<DataType>::operator[](size_t i) { + return m_vertexes[i]; +} + +template <typename DataType> +bool Polygon<DataType>::operator==(Polygon<DataType> const& rhs) const { + return m_vertexes == rhs.m_vertexes; +} + +template <typename DataType> +Polygon<DataType>& Polygon<DataType>::operator=(Polygon const& rhs) { + m_vertexes = rhs.m_vertexes; + return *this; +} + +template <typename DataType> +Polygon<DataType>& Polygon<DataType>::operator=(Polygon&& rhs) { + m_vertexes = move(rhs.m_vertexes); + return *this; +} + +template <typename DataType> +typename Polygon<DataType>::iterator Polygon<DataType>::begin() { + return m_vertexes.begin(); +} + +template <typename DataType> +typename Polygon<DataType>::const_iterator Polygon<DataType>::begin() const { + return m_vertexes.begin(); +} + +template <typename DataType> +typename Polygon<DataType>::iterator Polygon<DataType>::end() { + return m_vertexes.end(); +} + +template <typename DataType> +typename Polygon<DataType>::const_iterator Polygon<DataType>::end() const { + return m_vertexes.end(); +} + +template <typename DataType> +typename Polygon<DataType>::Vertex const& Polygon<DataType>::vertex(size_t i) const { + return m_vertexes[i % m_vertexes.size()]; +} + +template <typename DataType> +typename Polygon<DataType>::Vertex Polygon<DataType>::normal(size_t i) const { + Vertex diff = side(i).diff(); + + if (diff == Vertex()) + return Vertex(); + + return diff.rot90().normalized(); +} + +template <typename DataType> +typename Polygon<DataType>::Vertex Polygon<DataType>::center() const { + return std::accumulate(m_vertexes.begin(), m_vertexes.end(), Vertex()) / (DataType)m_vertexes.size(); +} + +template <typename DataType> +typename Polygon<DataType>::Vertex Polygon<DataType>::bottomCenter() const { + if (m_vertexes.size() == 0) + return Vertex(); + Polygon<DataType>::Vertex center = std::accumulate(m_vertexes.begin(), m_vertexes.end(), Vertex()) / (DataType)m_vertexes.size(); + Polygon<DataType>::Vertex bottomLeft = *std::min_element(m_vertexes.begin(), m_vertexes.end()); + Polygon<DataType>::Vertex topRight = *std::max_element(m_vertexes.begin(), m_vertexes.end()); + Polygon<DataType>::Vertex size = topRight - bottomLeft; + if (size.x() > size.y()) + return center; + return Polygon<DataType>::Vertex(center.x(), bottomLeft.y() + size.x() / 2.0f); +} + +template <typename DataType> +auto Polygon<DataType>::boundBox() const -> Rect { + auto bounds = Rect::null(); + for (auto const& v : m_vertexes) + bounds.combine(v); + return bounds; +} + +template <typename DataType> +int Polygon<DataType>::windingNumber(Vertex const& p) const { + + auto isLeft = [](Vertex const& p0, Vertex const& p1, Vertex const& p2) { + return ((p1[0] - p0[0]) * (p2[1] - p0[1]) - (p2[0] - p0[0]) * (p1[1] - p0[1])); + }; + + // the winding number counter + int wn = 0; + + // loop through all edges of the polygon + for (size_t i = 0; i < m_vertexes.size(); ++i) { + auto const& first = m_vertexes[i]; + auto const& second = i == m_vertexes.size() - 1 ? m_vertexes[0] : m_vertexes[i + 1]; + + // start y <= p[1] + if (first[1] <= p[1]) { + if (second[1] > p[1]) { + // an upward crossing + if (isLeft(first, second, p) > 0) { + // p left of edge + // have a valid up intersect + ++wn; + } + } + } else { + // start y > p[1] (no test needed) + if (second[1] <= p[1]) { + // a downward crossing + if (isLeft(first, second, p) < 0) { + // p right of edge + // have a valid down intersect + --wn; + } + } + } + } + + return wn; +} + +template <typename DataType> +bool Polygon<DataType>::contains(Vertex const& p) const { + return windingNumber(p) != 0; +} + +template <typename DataType> +typename Polygon<DataType>::IntersectResult Polygon<DataType>::satIntersection(Polygon const& p) const { + // "Accumulates" the shortest separating distance and axis of this poly and + // the given poly, after projecting all the vertexes of each poly onto a + // given axis. Used by SAT intersection, meant to be called with each tested + // axis. + auto accumSeparator = [this](Polygon const& p, Vertex const& axis, DataType& shortestOverlap, Vertex& finalSepDir) { + DataType myProjectionLow = std::numeric_limits<DataType>::max(); + DataType targetProjectionHigh = std::numeric_limits<DataType>::lowest(); + + for (auto const& v : m_vertexes) { + DataType p = axis[0] * v[0] + axis[1] * v[1]; + if (p < myProjectionLow) + myProjectionLow = p; + } + + for (auto const& v : p.m_vertexes) { + DataType p = axis[0] * v[0] + axis[1] * v[1]; + if (p > targetProjectionHigh) + targetProjectionHigh = p; + } + + float overlap = targetProjectionHigh - myProjectionLow; + if (overlap < shortestOverlap) { + shortestOverlap = overlap; + finalSepDir = axis; + } + }; + + DataType overlap = std::numeric_limits<DataType>::max(); + Vertex separatingDir = Vertex(); + + if (!m_vertexes.empty()) { + Vertex pv = m_vertexes[m_vertexes.size() - 1]; + for (auto const& v : m_vertexes) { + Vertex sideNormal = pv - v; + if (sideNormal != Vertex()) { + sideNormal = sideNormal.rot90().normalized(); + accumSeparator(p, -sideNormal, overlap, separatingDir); + } + pv = v; + } + } + + if (!p.m_vertexes.empty()) { + Vertex pv = p.m_vertexes[p.m_vertexes.size() - 1]; + for (auto const& v : p.m_vertexes) { + Vertex sideNormal = pv - v; + if (sideNormal != Vertex()) { + sideNormal = sideNormal.rot90().normalized(); + accumSeparator(p, sideNormal, overlap, separatingDir); + } + pv = v; + } + } + + IntersectResult isect; + isect.intersects = (overlap > 0); + isect.overlap = separatingDir * overlap; + + return isect; +} + +template <typename DataType> +typename Polygon<DataType>::IntersectResult Polygon<DataType>::directionalSatIntersection( + Polygon const& p, Vertex const& direction, bool chooseSign) const { + // A "directional" version of accumSeparator, that when intersecting only + // ever tries to separate in the given direction. + auto directionalAccumSeparator = [this](Polygon const& p, Vertex axis, DataType& shortestOverlap, + Vertex const& separatingDir, Vertex& finalSepDir, bool chooseDir) { + DataType myProjectionLow = std::numeric_limits<DataType>::max(); + DataType targetProjectionHigh = std::numeric_limits<DataType>::lowest(); + + for (auto const& v : m_vertexes) { + DataType p = axis[0] * v[0] + axis[1] * v[1]; + if (p < myProjectionLow) + myProjectionLow = p; + } + + for (auto const& v : p.m_vertexes) { + DataType p = axis[0] * v[0] + axis[1] * v[1]; + if (p > targetProjectionHigh) + targetProjectionHigh = p; + } + + float overlap = targetProjectionHigh - myProjectionLow; + + // Separation was found, skip the rest of the method. + if (overlap <= 0) { + if (overlap < shortestOverlap) { + shortestOverlap = overlap; + finalSepDir = axis; + } + return; + } + + DataType axisDot = separatingDir * axis; + + // Now, if we don't have separation and the axis is perpendicular to + // requested, we can do nothing, return. + if (axisDot == 0) + return; + + // Separate along the given separating direction enough to separate as + // determined by this axis. + DataType projOverlap = overlap / axisDot; + if (chooseDir) { + DataType absProjOverlap = (projOverlap >= 0) ? projOverlap : -projOverlap; + if (absProjOverlap < shortestOverlap) { + shortestOverlap = absProjOverlap; + finalSepDir = separatingDir * (projOverlap / absProjOverlap); + } + } else if (projOverlap >= 0) { + if (projOverlap < shortestOverlap) { + shortestOverlap = projOverlap; + finalSepDir = separatingDir; + } + } + }; + + DataType overlap = std::numeric_limits<DataType>::max(); + Vertex separatingDir = Vertex(); + + if (!m_vertexes.empty()) { + Vertex pv = m_vertexes[m_vertexes.size() - 1]; + for (auto const& v : m_vertexes) { + Vertex sideNormal = pv - v; + if (sideNormal != Vertex()) { + sideNormal = sideNormal.rot90().normalized(); + directionalAccumSeparator(p, -sideNormal, overlap, direction, separatingDir, chooseSign); + } + pv = v; + } + } + + if (!p.m_vertexes.empty()) { + Vertex pv = p.m_vertexes[p.m_vertexes.size() - 1]; + for (auto const& v : p.m_vertexes) { + Vertex sideNormal = pv - v; + if (sideNormal != Vertex()) { + sideNormal = sideNormal.rot90().normalized(); + directionalAccumSeparator(p, sideNormal, overlap, direction, separatingDir, chooseSign); + } + pv = v; + } + } + + IntersectResult isect; + isect.intersects = (overlap > 0); + isect.overlap = separatingDir * overlap; + + return isect; +} + +template <typename DataType> +auto Polygon<DataType>::lineIntersection(Line const& l) const -> Maybe<LineIntersectResult> { + if (contains(l.min())) + return LineIntersectResult{l.min(), DataType(0), {}}; + + Maybe<LineIntersectResult> nearestIntersection; + for (size_t i = 0; i < m_vertexes.size(); ++i) { + auto intersection = l.intersection(sideAt(i)); + if (intersection.intersects) { + if (!nearestIntersection || intersection.t < nearestIntersection->along) + nearestIntersection = LineIntersectResult{intersection.point, intersection.t, i}; + } + } + return nearestIntersection; +} + +template <typename DataType> +bool Polygon<DataType>::intersects(Polygon const& p) const { + return satIntersection(p).intersects; +} + +template <typename DataType> +bool Polygon<DataType>::intersects(Line const& l) const { + if (contains(l.min()) || contains(l.max())) + return true; + + for (size_t i = 0; i < m_vertexes.size(); ++i) { + if (l.intersects(sideAt(i))) + return true; + } + + return false; +} + +template <typename DataType> +auto Polygon<DataType>::sideAt(size_t i) const -> Line { + if (i == m_vertexes.size() - 1) + return Line(m_vertexes[i], m_vertexes[0]); + else + return Line(m_vertexes[i], m_vertexes[i + 1]); +} + +template <typename DataType> +std::ostream& operator<<(std::ostream& os, Polygon<DataType> const& poly) { + os << "[Poly: "; + for (auto i = poly.begin(); i != poly.end(); ++i) { + if (i != poly.begin()) + os << ", "; + os << *i; + } + os << "]"; + return os; +} + +} + +#endif diff --git a/source/core/StarPythonic.hpp b/source/core/StarPythonic.hpp new file mode 100644 index 0000000..014e14e --- /dev/null +++ b/source/core/StarPythonic.hpp @@ -0,0 +1,603 @@ +#ifndef STAR_PYTHONIC_HPP +#define STAR_PYTHONIC_HPP + +#include "StarAlgorithm.hpp" + +namespace Star { + +// any and all + +template <typename Iterator, typename Functor> +bool any(Iterator iterBegin, Iterator iterEnd, Functor const& f) { + for (; iterBegin != iterEnd; iterBegin++) + if (f(*iterBegin)) + return true; + return false; +} + +template <typename Iterator> +bool any(Iterator const& iterBegin, Iterator const& iterEnd) { + typedef typename std::iterator_traits<Iterator>::value_type IteratorValue; + std::function<bool(IteratorValue)> compare = [](IteratorValue const& i) { return (bool)i; }; + return any(iterBegin, iterEnd, compare); +} + +template <typename Iterable, typename Functor> +bool any(Iterable const& iter, Functor const& f) { + return any(std::begin(iter), std::end(iter), f); +} + +template <typename Iterable> +bool any(Iterable const& iter) { + typedef decltype(*std::begin(iter)) IteratorValue; + std::function<bool(IteratorValue)> compare = [](IteratorValue const& i) { return (bool)i; }; + return any(std::begin(iter), std::end(iter), compare); +} + +template <typename Iterator, typename Functor> +bool all(Iterator iterBegin, Iterator iterEnd, Functor const& f) { + for (; iterBegin != iterEnd; iterBegin++) + if (!f(*iterBegin)) + return false; + return true; +} + +template <typename Iterator> +bool all(Iterator const& iterBegin, Iterator const& iterEnd) { + typedef typename std::iterator_traits<Iterator>::value_type IteratorValue; + std::function<bool(IteratorValue)> compare = [](IteratorValue const& i) { return (bool)i; }; + return all(iterBegin, iterEnd, compare); +} + +template <typename Iterable, typename Functor> +bool all(Iterable const& iter, Functor const& f) { + return all(std::begin(iter), std::end(iter), f); +} + +template <typename Iterable> +bool all(Iterable const& iter) { + typedef decltype(*std::begin(iter)) IteratorValue; + std::function<bool(IteratorValue)> compare = [](IteratorValue const& i) { return (bool)i; }; + return all(std::begin(iter), std::end(iter), compare); +} + +// Python style container slicing + +struct SliceIndex { + SliceIndex() : index(0), given(false) {} + SliceIndex(int i) : index(i), given(true) {} + + int index; + bool given; +}; + +SliceIndex const SliceNil = SliceIndex(); + +// T must have operator[](int), size(), and +// push_back(typeof T::operator[](int())) +template <typename Res, typename In> +Res slice(In const& r, SliceIndex a, SliceIndex b = SliceIndex(), int j = 1) { + int size = (int)r.size(); + int start, end; + + // Throw exception on j == 0? + if (j == 0 || size == 0) + return Res(); + + if (!a.given) { + if (j > 0) + start = 0; + else + start = size - 1; + } else if (a.index < 0) { + if (-a.index > size - 1) + start = 0; + else + start = size - -a.index; + } else { + if (a.index > size) + start = size; + else + start = a.index; + } + + if (!b.given) { + if (j > 0) + end = size; + else + end = -1; + } else if (b.index < 0) { + if (-b.index > size - 1) { + end = -1; + } else { + end = size - -b.index; + } + } else { + if (b.index > size - 1) { + end = size; + } else { + end = b.index; + } + } + + if (start < end && j < 0) + return Res(); + if (start > end && j > 0) + return Res(); + + Res returnSlice; + int i; + for (i = start; i < end; i += j) + returnSlice.push_back(r[i]); + + return returnSlice; +} + +template <typename T> +T slice(T const& r, SliceIndex a, SliceIndex b = SliceIndex(), int j = 1) { + return slice<T, T>(r, a, b, j); +} + +// ZIP + +// Wraps a regular iterator and returns a singleton tuple, as well as +// supporting the iterator protocol that the zip iterator code expects. +template <typename IteratorT> +class ZipWrapperIterator { +private: + IteratorT current; + IteratorT last; + bool atEnd; + +public: + typedef IteratorT Iterator; + typedef decltype(*std::declval<Iterator>()) IteratorValue; + typedef tuple<IteratorValue> value_type; + + ZipWrapperIterator() : atEnd(true) {} + + ZipWrapperIterator(Iterator current, Iterator last) : current(current), last(last) { + atEnd = current == last; + } + + ZipWrapperIterator operator++() { + if (!atEnd) { + ++current; + atEnd = current == last; + } + + return *this; + } + + value_type operator*() const { + return std::tuple<IteratorValue>(*current); + } + + bool operator==(ZipWrapperIterator const& rhs) const { + return (atEnd && rhs.atEnd) || (!atEnd && !rhs.atEnd && current == rhs.current && last == rhs.last); + } + + bool operator!=(ZipWrapperIterator const& rhs) const { + return !(*this == rhs); + } + + explicit operator bool() const { + return !atEnd; + } + + ZipWrapperIterator begin() const { + return *this; + } + + ZipWrapperIterator end() const { + return ZipWrapperIterator(); + } +}; +template <typename IteratorT> +ZipWrapperIterator<IteratorT> makeZipWrapperIterator(IteratorT current, IteratorT end) { + return ZipWrapperIterator<IteratorT>(current, end); +} + +// Takes two ZipIterators / ZipTupleIterators and concatenates them into a +// single iterator that returns the concatenated tuple. +template <typename TailIteratorT, typename HeadIteratorT> +class ZipTupleIterator { +private: + TailIteratorT tailIterator; + HeadIteratorT headIterator; + bool atEnd; + +public: + typedef TailIteratorT TailIterator; + typedef HeadIteratorT HeadIterator; + + typedef decltype(*TailIterator()) TailType; + typedef decltype(*HeadIterator()) HeadType; + + typedef decltype(std::tuple_cat(std::declval<TailType>(), std::declval<HeadType>())) value_type; + + ZipTupleIterator() : atEnd(true) {} + + ZipTupleIterator(TailIterator tailIterator, HeadIterator headIterator) + : tailIterator(tailIterator), headIterator(headIterator) { + atEnd = tailIterator == TailIterator() || headIterator == HeadIterator(); + } + + ZipTupleIterator operator++() { + if (!atEnd) { + ++tailIterator; + ++headIterator; + atEnd = tailIterator == TailIterator() || headIterator == HeadIterator(); + } + + return *this; + } + + value_type operator*() const { + return std::tuple_cat(*tailIterator, *headIterator); + } + + bool operator==(ZipTupleIterator const& rhs) const { + return (atEnd && rhs.atEnd) + || (!atEnd && !rhs.atEnd && tailIterator == rhs.tailIterator && headIterator == rhs.headIterator); + } + + bool operator!=(ZipTupleIterator const& rhs) const { + return !(*this == rhs); + } + + explicit operator bool() const { + return !atEnd; + } + + ZipTupleIterator begin() const { + return *this; + } + + ZipTupleIterator end() const { + return ZipTupleIterator(); + } +}; +template <typename HeadIteratorT, typename TailIteratorT> +ZipTupleIterator<HeadIteratorT, TailIteratorT> makeZipTupleIterator(HeadIteratorT head, TailIteratorT tail) { + return ZipTupleIterator<HeadIteratorT, TailIteratorT>(head, tail); +} + +template <typename Container, typename... Rest> +struct zipIteratorReturn { + typedef ZipTupleIterator<typename zipIteratorReturn<Container>::type, typename zipIteratorReturn<Rest...>::type> type; +}; + +template <typename Container> +struct zipIteratorReturn<Container> { + typedef ZipWrapperIterator<decltype(std::declval<Container>().begin())> type; +}; + +template <typename Container> +typename zipIteratorReturn<Container>::type zipIterator(Container& container) { + return makeZipWrapperIterator(container.begin(), container.end()); +} + +template <typename Container, typename... Rest> +typename zipIteratorReturn<Container, Rest...>::type zipIterator(Container& container, Rest&... rest) { + return makeZipTupleIterator(makeZipWrapperIterator(container.begin(), container.end()), zipIterator(rest...)); +} + +// END ZIP + +// RANGE + +namespace RangeHelper { + + template <typename Diff> + typename std::enable_if<std::is_unsigned<Diff>::value, bool>::type checkIfDiffLessThanZero(Diff) { + return false; + } + + template <typename Diff> + typename std::enable_if<!std::is_unsigned<Diff>::value, bool>::type checkIfDiffLessThanZero(Diff diff) { + return diff < 0; + } +} + +STAR_EXCEPTION(RangeException, StarException); + +template <typename Value, typename Diff = int> +class RangeIterator : public std::iterator<std::random_access_iterator_tag, Value> { +public: + RangeIterator() : m_start(), m_end(), m_diff(1), m_current(), m_stop(true) {} + + RangeIterator(Value min, Value max, Diff diff) + : m_start(min), m_end(max), m_diff(diff), m_current(min), m_stop(false) { + sanity(); + } + + RangeIterator(Value min, Value max) : m_start(min), m_end(max), m_diff(1), m_current(min), m_stop(false) { + sanity(); + } + + RangeIterator(Value max) : m_start(), m_end(max), m_diff(1), m_current(), m_stop(false) { + sanity(); + } + + RangeIterator(RangeIterator const& rhs) { + copy(rhs); + } + + RangeIterator& operator=(RangeIterator const& rhs) { + copy(rhs); + return *this; + } + + RangeIterator& operator+=(Diff steps) { + if ((applySteps(m_current, m_diff * steps) >= m_end) != (RangeHelper::checkIfDiffLessThanZero<Diff>(m_diff))) { + if (!m_stop) { + Diff stepsLeft = stepsBetween(m_current, m_end); + m_current = applySteps(m_current, stepsLeft * m_diff); + m_stop = true; + } + } else { + m_current = applySteps(m_current, steps * m_diff); + } + return *this; + } + + RangeIterator operator-=(Diff steps) const { + m_stop = false; + sanity(); + + if (applySteps(m_current, -(m_diff * steps)) < m_start) + m_current = m_start; + else + m_current = applySteps(m_current, -(m_diff * steps)); + + return *this; + } + + Value operator*() const { + return m_current; + } + + Value const* operator->() const { + return &m_current; + } + + Value operator[](unsigned rhs) const { + // Should return at maximum, the value that this iterator will normally + // reach when at end(). + rhs = std::min(rhs, stepsBetween(m_start, m_end) + 1); + return m_start + rhs * m_diff; + } + + RangeIterator& operator++() { + return operator+=(1); + } + + RangeIterator& operator--() { + return operator-=(1); + } + + RangeIterator operator++(int) { + RangeIterator tmp(*this); + ++this; + return tmp; + } + + RangeIterator operator--(int) { + RangeIterator tmp(*this); + --this; + return tmp; + } + + RangeIterator operator+(Diff steps) const { + RangeIterator copy(*this); + copy += steps; + return copy; + } + + RangeIterator operator-(Diff steps) const { + RangeIterator copy(*this); + copy -= steps; + return copy; + } + + int operator-(RangeIterator const& rhs) const { + if (!sameClass(rhs)) + throw RangeException("Attempted to subtract incompatible ranges."); + + return stepsBetween(rhs.m_current, m_current); + } + + friend RangeIterator operator+(Diff lhs, RangeIterator const& rhs) { + return rhs + lhs; + } + + friend RangeIterator operator-(Diff lhs, RangeIterator const& rhs) { + return rhs - lhs; + } + + bool operator==(RangeIterator const& rhs) const { + return (sameClass(rhs) && m_current == rhs.m_current && m_stop == rhs.m_stop); + } + + bool operator!=(RangeIterator const& rhs) const { + return !(*this == rhs); + } + + bool operator<(RangeIterator const& rhs) const { + return std::tie(m_start, m_end, m_diff, m_current) < std::tie(rhs.m_start, rhs.m_end, rhs.m_diff, rhs.m_current); + } + + bool operator<=(RangeIterator const& rhs) const { + return (*this == rhs) || (*this < rhs); + } + + bool operator>=(RangeIterator const& rhs) const { + return !(*this < rhs); + } + + bool operator>(RangeIterator const& rhs) const { + return !(*this <= rhs); + } + + RangeIterator begin() const { + return RangeIterator(m_start, m_end, m_diff); + } + + RangeIterator end() const { + Diff steps = stepsBetween(m_start, m_end); + RangeIterator res(m_start, m_end, m_diff); + res += steps; + return res; + } + +private: + void copy(RangeIterator const& copy) { + m_start = copy.m_start; + m_end = copy.m_end; + m_diff = copy.m_diff; + m_current = copy.m_current; + m_stop = copy.m_stop; + sanity(); + } + + void sanity() { + if (m_diff == 0) + throw RangeException("Invalid difference in range function."); + + if ((m_end < m_start) != (RangeHelper::checkIfDiffLessThanZero<Diff>(m_diff))) { + if (RangeHelper::checkIfDiffLessThanZero<Diff>(m_diff)) + throw RangeException("Start cannot be less than end if diff is negative."); + throw RangeException("Max cannot be less than min."); + } + + if (m_end == m_start) + m_stop = true; + } + + bool sameClass(RangeIterator const& rhs) const { + return m_start == rhs.m_start && m_end == rhs.m_end && m_diff == rhs.m_diff; + } + + Diff stepsBetween(Value start, Value end) const { + return ((Diff)end - (Diff)start) / m_diff; + } + + Value applySteps(Value start, Diff travel) const { + return (Value)((Diff)start + travel); + } + + Value m_start; + Value m_end; + Diff m_diff; + + Value m_current; + + bool m_stop; +}; + +template <typename Numeric, typename Diff> +RangeIterator<Numeric, Diff> range(Numeric min, Numeric max, Diff diff) { + return RangeIterator<Numeric, Diff>(min, max, diff); +} + +template <typename Numeric, typename Diff = int> +RangeIterator<Numeric, Diff> range(Numeric max) { + return RangeIterator<Numeric, Diff>(max); +} + +template <typename Numeric, typename Diff = int> +RangeIterator<Numeric, Diff> range(Numeric min, Numeric max) { + return RangeIterator<Numeric, Diff>(min, max); +} + +template <typename Numeric, typename Diff> +RangeIterator<Numeric, Diff> rangeInclusive(Numeric min, Numeric max, Diff diff) { + return RangeIterator<Numeric, Diff>(min, (Numeric)((Diff)max + 1), diff); +} + +template <typename Numeric, typename Diff = int> +RangeIterator<Numeric, Diff> rangeInclusive(Numeric max) { + return RangeIterator<Numeric, Diff>((Numeric)((Diff)max + 1)); +} + +template <typename Numeric, typename Diff = int> +RangeIterator<Numeric, Diff> rangeInclusive(Numeric min, Numeric max) { + return RangeIterator<Numeric, Diff>(min, (Numeric)((Diff)max + 1)); +} + +// END RANGE + +// Wraps a forward-iterator to produce {value, index} pairs, similar to +// python's enumerate() +template <typename Iterator> +struct EnumerateIterator { +private: + Iterator current; + Iterator last; + size_t index; + bool atEnd; + +public: + typedef decltype(*std::declval<Iterator>()) IteratorValue; + typedef pair<IteratorValue&, size_t> value_type; + + EnumerateIterator() : index(0), atEnd(true) {} + + EnumerateIterator(Iterator begin, Iterator end) : current(begin), last(end), index(0) { + atEnd = current == last; + } + + EnumerateIterator begin() const { + return *this; + } + + EnumerateIterator end() const { + return EnumerateIterator(); + } + + EnumerateIterator operator++() { + if (!atEnd) { + ++current; + ++index; + + atEnd = current == last; + } + + return *this; + } + + value_type operator*() const { + return {*current, index}; + } + + bool operator==(EnumerateIterator const& rhs) const { + return (atEnd && rhs.atEnd) || (!atEnd && !rhs.atEnd && current == rhs.current && last == rhs.last); + } + + bool operator!=(EnumerateIterator const& rhs) const { + return !(*this == rhs); + } + + explicit operator bool() const { + return !atEnd; + } +}; + +template <typename Iterable> +EnumerateIterator<decltype(std::declval<Iterable>().begin())> enumerateIterator(Iterable& list) { + return EnumerateIterator<decltype(std::declval<Iterable>().begin())>(list.begin(), list.end()); +} + +template <typename ResultContainer, typename Iterable> +ResultContainer enumerateConstruct(Iterable&& list) { + ResultContainer res; + for (auto el : enumerateIterator(list)) + res.push_back(move(el)); + + return res; +} + +} + +#endif diff --git a/source/core/StarRandom.cpp b/source/core/StarRandom.cpp new file mode 100644 index 0000000..abd3984 --- /dev/null +++ b/source/core/StarRandom.cpp @@ -0,0 +1,346 @@ +#include "StarRandom.hpp" +#include "StarThread.hpp" +#include "StarTime.hpp" +#include "StarMathCommon.hpp" + +namespace Star { + +RandomSource::RandomSource() { + init(Random::randu64()); +} + +void RandomSource::init() { + init(Random::randu64()); +} + +RandomSource::RandomSource(uint64_t seed) { + init(seed); +} + +void RandomSource::init(uint64_t seed) { + /* choose random initial m_carry < 809430660 and */ + /* 256 random 32-bit integers for m_data[] */ + m_carry = seed % 809430660; + + m_data[0] = seed; + m_data[1] = seed >> 32; + + for (size_t i = 2; i < 256; ++i) + m_data[i] = 69069 * m_data[i - 2] + 362437; + + m_index = 255; + + // Hard-coded initial skip of random values, to get the random generator + // going. + const unsigned RandomInitialSkip = 32; + for (unsigned i = 0; i < RandomInitialSkip; ++i) + gen32(); +} + +void RandomSource::addEntropy() { + addEntropy(Random::randu64()); +} + +void RandomSource::addEntropy(uint64_t seed) { + // to avoid seed aliasing + seed ^= randu64(); + + // Same algo as init, but bitwise xor with existing data + + m_carry = (m_carry ^ seed) % 809430660; + + m_data[0] ^= seed; + m_data[1] ^= (seed >> 32) ^ seed; + + for (size_t i = 2; i < 256; ++i) + m_data[i] ^= 69069 * m_data[i - 2] + 362437; +} + +uint32_t RandomSource::randu32() { + return gen32(); +} + +uint64_t RandomSource::randu64() { + uint64_t r = randu32(); + r = r << 32; + r = r | randu32(); + return r; +} + +int32_t RandomSource::randi32() { + return (int32_t)(randu32()); +} + +int64_t RandomSource::randi64() { + return (int64_t)(randu64()); +} + +float RandomSource::randf() { + return (randu32() & 0x7fffffff) / 2147483648.0f; +} + +double RandomSource::randd() { + return (randu64() & 0x7fffffffffffffff) / 9223372036854775808.0; +} + +int64_t RandomSource::randInt(int64_t max) { + return randUInt(max); +} + +uint64_t RandomSource::randUInt(uint64_t max) { + uint64_t denom = (uint64_t)(-1) / ((uint64_t)max + 1); + return randu64() / denom; +} + +int64_t RandomSource::randInt(int64_t min, int64_t max) { + if (max < min) + throw StarException("Maximum bound in randInt must be >= minimum bound!"); + return randInt(max - min) + min; +} + +uint64_t RandomSource::randUInt(uint64_t min, uint64_t max) { + if (max < min) + throw StarException("Maximum bound in randUInt must be >= minimum bound!"); + return randUInt(max - min) + min; +} + +float RandomSource::randf(float min, float max) { + if (max < min) + throw StarException("Maximum bound in randf must be >= minimum bound!"); + return randf() * (max - min) + min; +} + +double RandomSource::randd(double min, double max) { + if (max < min) + throw StarException("Maximum bound in randd must be >= minimum bound!"); + return randd() * (max - min) + min; +} + +bool RandomSource::randb() { + uint32_t v = gen32(); + bool parity = false; + while (v) { + parity = !parity; + v = v & (v - 1); + } + return parity; +} + +void RandomSource::randBytes(char* buf, size_t len) { + while (len) { + uint32_t ui = gen32(); + for (size_t i = 0; i < 4; ++i) { + if (len) { + *buf = (char)(ui >> (i * 8)); + --len; + ++buf; + } + } + } +} + +ByteArray RandomSource::randBytes(size_t len) { + ByteArray array(len, 0); + randBytes(array.ptr(), len); + return array; +} + +// normal distribution via Box-Muller +float RandomSource::nrandf(float stddev, float mean) { + float rand1, rand2, distSqr; + do { + rand1 = 2 * randf() - 1; + rand2 = 2 * randf() - 1; + distSqr = rand1 * rand1 + rand2 * rand2; + } while (distSqr >= 1); + + float mapping = std::sqrt(-2 * std::log(distSqr) / distSqr); + return (rand1 * mapping * stddev + mean); +} + +double RandomSource::nrandd(double stddev, double mean) { + double rand1, rand2, distSqr; + do { + rand1 = 2 * randd() - 1; + rand2 = 2 * randd() - 1; + distSqr = rand1 * rand1 + rand2 * rand2; + } while (distSqr >= 1); + + double mapping = std::sqrt(-2 * std::log(distSqr) / distSqr); + return (rand1 * mapping * stddev + mean); +} + +int64_t RandomSource::stochasticRound(double val) { + double fpart = val - floor(val); + if (randd() < fpart) + return ceil(val); + else + return floor(val); +} + +uint32_t RandomSource::gen32() { + uint64_t a = 809430660; + uint64_t t = a * m_data[++m_index] + m_carry; + + m_carry = (t >> 32); + m_data[m_index] = t; + + return t; +} + +namespace Random { + static Maybe<RandomSource> g_randSource; + static Mutex g_randMutex; + + static uint64_t produceRandomSeed() { + int64_t seed = Time::monotonicTicks(); + seed *= 1099511628211; + seed ^= (((int64_t)rand()) << 32) | ((int64_t)rand()); + return seed; + } + + void doInit(uint64_t seed) { + g_randSource = RandomSource(seed); + // Also set the C stdlib random seed + srand(seed); + } + + void checkInit() { + // Mutex must already be held + if (!g_randSource) { + doInit(produceRandomSeed()); + } + } + + void init() { + MutexLocker locker(g_randMutex); + doInit(produceRandomSeed()); + } + + void init(uint64_t seed) { + MutexLocker locker(g_randMutex); + doInit(seed); + } + + void addEntropy() { + MutexLocker locker(g_randMutex); + checkInit(); + g_randSource->addEntropy(produceRandomSeed()); + } + + void addEntropy(uint64_t seed) { + MutexLocker locker(g_randMutex); + checkInit(); + g_randSource->addEntropy(seed); + } + + uint32_t randu32() { + MutexLocker locker(g_randMutex); + checkInit(); + return g_randSource->randu32(); + } + + uint64_t randu64() { + MutexLocker locker(g_randMutex); + checkInit(); + return g_randSource->randu64(); + } + + int32_t randi32() { + MutexLocker locker(g_randMutex); + checkInit(); + return g_randSource->randi32(); + } + + int64_t randi64() { + MutexLocker locker(g_randMutex); + checkInit(); + return g_randSource->randi64(); + } + + float randf() { + MutexLocker locker(g_randMutex); + checkInit(); + return g_randSource->randf(); + } + + double randd() { + MutexLocker locker(g_randMutex); + checkInit(); + return g_randSource->randd(); + } + + float randf(float min, float max) { + MutexLocker locker(g_randMutex); + checkInit(); + return g_randSource->randf(min, max); + } + + double randd(double min, double max) { + MutexLocker locker(g_randMutex); + checkInit(); + return g_randSource->randd(min, max); + } + + bool randb() { + MutexLocker locker(g_randMutex); + checkInit(); + return g_randSource->randb(); + } + + long long randInt(long long max) { + MutexLocker locker(g_randMutex); + checkInit(); + return g_randSource->randInt(max); + } + + unsigned long long randUInt(unsigned long long max) { + MutexLocker locker(g_randMutex); + checkInit(); + return g_randSource->randUInt(max); + } + + long long randInt(long long min, long long max) { + MutexLocker locker(g_randMutex); + checkInit(); + return g_randSource->randInt(min, max); + } + + unsigned long long randUInt(unsigned long long min, unsigned long long max) { + MutexLocker locker(g_randMutex); + checkInit(); + return g_randSource->randUInt(min, max); + } + + float nrandf(float stddev, float mean) { + MutexLocker locker(g_randMutex); + checkInit(); + return g_randSource->nrandf(stddev, mean); + } + + double nrandd(double stddev, double mean) { + MutexLocker locker(g_randMutex); + checkInit(); + return g_randSource->nrandd(stddev, mean); + } + + int64_t stochasticRound(double val) { + MutexLocker locker(g_randMutex); + checkInit(); + return g_randSource->stochasticRound(val); + } + + void randBytes(char* buf, size_t len) { + MutexLocker locker(g_randMutex); + checkInit(); + g_randSource->randBytes(buf, len); + } + + ByteArray randBytes(size_t len) { + MutexLocker locker(g_randMutex); + checkInit(); + return g_randSource->randBytes(len); + } +} + +} diff --git a/source/core/StarRandom.hpp b/source/core/StarRandom.hpp new file mode 100644 index 0000000..b3c8ba2 --- /dev/null +++ b/source/core/StarRandom.hpp @@ -0,0 +1,216 @@ +#ifndef STAR_RANDOM_HPP +#define STAR_RANDOM_HPP + +#include "StarStaticRandom.hpp" +#include "StarByteArray.hpp" + +namespace Star { + +STAR_EXCEPTION(RandomException, StarException); + +// Deterministic random number source. Uses multiply-with-carry algorithm. +// Much higher quality than the predictable random number generators. Not +// thread safe (won't crash or anything, but might return less than optimal +// values). +class RandomSource { +public: + // Generates a RandomSource with a seed from Random::randu64() + RandomSource(); + RandomSource(uint64_t seed); + + // Re-initializes the random number generator using the given seed. It is + // exactly equivalent to constructing a new RandomSource, just using the same + // buffer. + void init(); + void init(uint64_t seed); + + void addEntropy(); + void addEntropy(uint64_t seed); + + uint32_t randu32(); + uint64_t randu64(); + + int32_t randi32(); + int64_t randi64(); + + // Generates values in the range [0.0, 1.0] + float randf(); + // Generates values in the range [0.0, 1.0] + double randd(); + + // Random integer from [0, max], max must be >= 0 + int64_t randInt(int64_t max); + uint64_t randUInt(uint64_t max); + + // Random integer from [min, max] + int64_t randInt(int64_t min, int64_t max); + uint64_t randUInt(uint64_t min, uint64_t max); + + float randf(float min, float max); + double randd(double min, double max); + + bool randb(); + + // Generates values via normal distribution with box-muller algorithm + float nrandf(float stddev = 1.0f, float mean = 0.0f); + double nrandd(double stddev = 1.0, double mean = 0.0); + + // Round a fractional value statistically towards the floor or ceiling. For + // example, if a value is 5.2, 80% of the time it will round to 5, but 20% of + // the time it will round to 6. + int64_t stochasticRound(double val); + + void randBytes(char* buf, size_t len); + ByteArray randBytes(size_t len); + + // Pick a random value out of a container + template <typename Container> + typename Container::value_type const& randFrom(Container const& container); + template <typename Container> + typename Container::value_type& randFrom(Container& container); + template <typename Container> + typename Container::value_type randValueFrom(Container const& container); + template <typename Container> + typename Container::value_type randValueFrom(Container const& container, typename Container::value_type const& defaultVal); + + template <typename Container> + void shuffle(Container& container); + +private: + uint32_t gen32(); + + uint32_t m_data[256]; + uint32_t m_carry; + uint8_t m_index; +}; + +// Global static versions of the methods in RandomSource. It is not necessary +// to initialize the global RandomSource manually, it will be automatically +// initialized with a random seed on first use if it is not already initialized. +namespace Random { + void init(); + void init(uint64_t seed); + + void addEntropy(); + void addEntropy(uint64_t seed); + + uint32_t randu32(); + uint64_t randu64(); + int32_t randi32(); + int64_t randi64(); + float randf(); + double randd(); + long long randInt(long long max); + unsigned long long randUInt(unsigned long long max); + long long randInt(long long min, long long max); + unsigned long long randUInt(unsigned long long min, unsigned long long max); + float randf(float min, float max); + double randd(double min, double max); + bool randb(); + + float nrandf(float stddev = 1.0f, float mean = 0.0f); + double nrandd(double stddev = 1.0, double mean = 0.0); + + int64_t stochasticRound(double val); + + void randBytes(char* buf, size_t len); + ByteArray randBytes(size_t len); + + template <typename Container> + typename Container::value_type const& randFrom(Container const& container); + template <typename Container> + typename Container::value_type& randFrom(Container& container); + template <typename Container> + typename Container::value_type randValueFrom(Container const& container); + template <typename Container> + typename Container::value_type randValueFrom(Container const& container, typename Container::value_type const& defaultVal); + + template <typename Container> + void shuffle(Container& container); +} + +template <typename Container> +typename Container::value_type const& RandomSource::randFrom(Container const& container) { + if (container.empty()) + throw RandomException("Empty container in randFrom"); + + auto i = container.begin(); + std::advance(i, randUInt(container.size() - 1)); + return *i; +} + +template <typename Container> +typename Container::value_type& RandomSource::randFrom(Container& container) { + if (container.empty()) + throw RandomException("Empty container in randFrom"); + + auto i = container.begin(); + std::advance(i, randUInt(container.size() - 1)); + return *i; +} + +template <typename Container> +typename Container::value_type const& Random::randFrom(Container const& container) { + if (container.empty()) + throw RandomException("Empty container in randFrom"); + + auto i = container.begin(); + std::advance(i, Random::randUInt(container.size() - 1)); + return *i; +} + +template <typename Container> +typename Container::value_type& Random::randFrom(Container& container) { + if (container.empty()) + throw RandomException("Empty container in randFrom"); + + auto i = container.begin(); + std::advance(i, Random::randUInt(container.size() - 1)); + return *i; +} + +template <typename Container> +typename Container::value_type RandomSource::randValueFrom(Container const& container) { + return randValueFrom(container, typename Container::value_type()); +} + +template <typename Container> +typename Container::value_type RandomSource::randValueFrom( + Container const& container, typename Container::value_type const& defaultVal) { + if (container.empty()) + return defaultVal; + + auto i = container.begin(); + std::advance(i, randInt(container.size() - 1)); + return *i; +} + +template <typename Container> +void RandomSource::shuffle(Container& container) { + std::random_shuffle(container.begin(), container.end(), [this](size_t max) { return randUInt(max - 1); }); +} + +template <typename Container> +typename Container::value_type Random::randValueFrom(Container const& container) { + return randValueFrom(container, typename Container::value_type()); +} + +template <typename Container> +typename Container::value_type Random::randValueFrom( + Container const& container, typename Container::value_type const& defaultVal) { + if (container.empty()) + return defaultVal; + + auto i = container.begin(); + std::advance(i, Random::randInt(container.size() - 1)); + return *i; +} + +template <typename Container> +void Random::shuffle(Container& container) { + std::random_shuffle(container.begin(), container.end(), [](size_t max) { return Random::randUInt(max - 1); }); +} + +} + +#endif diff --git a/source/core/StarRandomPoint.hpp b/source/core/StarRandomPoint.hpp new file mode 100644 index 0000000..ff7d354 --- /dev/null +++ b/source/core/StarRandomPoint.hpp @@ -0,0 +1,79 @@ +#ifndef STAR_RANDOM_POINT_HPP +#define STAR_RANDOM_POINT_HPP + +#include "StarRandom.hpp" +#include "StarPoly.hpp" +#include "StarTtlCache.hpp" + +namespace Star { + +// An "infinite" generator of points on a 2d plane, generated cell by cell with +// an upper and lower cell density range. Each point is generated in a +// predictable way sector by sector, as long as the generator function is +// predictable and uses the RandomSource in a predictable way. Useful for +// things like starfields, fields of debris, random object placement, etc. + +template <typename PointData> +class Random2dPointGenerator { +public: + typedef List<pair<Vec2F, PointData>> PointSet; + + Random2dPointGenerator(uint64_t seed, float cellSize, Vec2I const& densityRange); + + // Each point will in the area will be generated in a predictable order, and + // if the callback uses the RandomSource in a predictable way, will generate + // the same field for every call. + template <typename PointCallback> + PointSet generate(PolyF const& area, PointCallback callback); + +private: + HashTtlCache<Vec2F, PointSet> m_cache; + + uint64_t m_seed; + float m_cellSize; + Vec2I m_densityRange; +}; + +template <typename PointData> +inline Random2dPointGenerator<PointData>::Random2dPointGenerator(uint64_t seed, float cellSize, Vec2I const& densityRange) + : m_seed(seed), m_cellSize(cellSize), m_densityRange(densityRange) {} + +template <typename PointData> +template <typename PointCallback> +auto Random2dPointGenerator<PointData>::generate(PolyF const& area, PointCallback callback) -> PointSet { + auto bound = area.boundBox(); + int64_t sectorXMin = std::floor(bound.xMin() / m_cellSize); + int64_t sectorYMin = std::floor(bound.yMin() / m_cellSize); + int64_t sectorXMax = std::ceil(bound.xMax() / m_cellSize); + int64_t sectorYMax = std::ceil(bound.yMax() / m_cellSize); + + PointSet finalResult; + + for (int64_t x = sectorXMin; x <= sectorXMax; ++x) { + for (int64_t y = sectorYMin; y <= sectorYMax; ++y) { + auto sector = RectF::withSize({x * m_cellSize, y * m_cellSize}, Vec2F::filled(m_cellSize)); + if (!area.intersects(PolyF(sector))) + continue; + + finalResult.appendAll(m_cache.get(Vec2F(x, y), [&](Vec2F const&) { + PointSet sectorResult; + + RandomSource sectorRandomness(staticRandomU64(m_seed, x, y)); + + unsigned max = sectorRandomness.randInt(m_densityRange[0], m_densityRange[1]); + for (unsigned i = 0; i < max; ++i) { + Vec2F pointPos = Vec2F(x + sectorRandomness.randf(), y + sectorRandomness.randf()) * m_cellSize; + sectorResult.append(pair<Vec2F, PointData>(pointPos, callback(sectorRandomness))); + } + + return sectorResult; + })); + } + } + + return finalResult; +} + +} + +#endif diff --git a/source/core/StarRect.hpp b/source/core/StarRect.hpp new file mode 100644 index 0000000..38bbabc --- /dev/null +++ b/source/core/StarRect.hpp @@ -0,0 +1,1068 @@ +#ifndef STAR_RECT_HPP +#define STAR_RECT_HPP + +#include "StarLine.hpp" +#include "StarList.hpp" + +namespace Star { + +// Axis aligned box that can be used as a bounding volume. +template <typename T, size_t N> +class Box { +public: + typedef Vector<T, N> Coord; + typedef Star::Line<T, N> Line; + typedef typename Line::IntersectResult LineIntersectResult; + + template <size_t P, typename T2 = void> + using Enable2D = typename std::enable_if<P == 2 && N == P, T2>::type; + + struct IntersectResult { + // Whether or not the two objects intersect + bool intersects; + // How much *this* box must be moved in order to make them not intersect + // anymore + Coord overlap; + // Whether or not the intersection is touching only. No overlap. + bool glances; + }; + + static Box null(); + static Box inf(); + + // Returns an integral aligned box that at least contains the given floating + // point box. + template <typename Box2> + static Box integral(Box2 const& box); + + // Returns an integral aligned box that is equal to the given box rounded to + // the nearest whole number (does not necessarily contain the given box). + template <typename Box2> + static Box round(Box2 const& box); + + template <typename... TN> + static Box boundBoxOf(TN const&... list); + + template <typename Collection> + static Box boundBoxOfPoints(Collection const& collection); + + static Box withSize(Coord const& min, Coord const& size); + static Box withCenter(Coord const& center, Coord const& size); + + Box(); + Box(Coord const& min, Coord const& max); + Box(Box const& b); + + template <typename T2> + explicit Box(Box<T2, N> const& b); + + // Is equal to null() + bool isNull() const; + + // One or more dimensions are of negative magnitude + bool isNegative() const; + + // One or more dimensions are of zero or negative magnitude + bool isEmpty() const; + + // Sets the bounding box equal to one containing the given bounding box and + // the current one. + void combine(Box const& box); + Box combined(Box const& box) const; + + // Sets the bounding box equal to one containing the current bounding box and + // the given point. + void combine(Coord const& point); + Box combined(Coord const& point) const; + + // Sets the bounding box equal to the intersection of this one and the given + // one. If there is no intersection than the box becomes negative in that + // dimension. + void limit(Box const& box); + Box limited(Box const& box); + + // If any range has a min < max, swap them to make it non-null. + void makePositive(); + + // Sets any empty (or negative) dimensions in the bounding box to the + // corresponding range in the given bounding box. If the bounding box is not + // empty in any dimension, then this has no effect. + void rangeSetIfEmpty(Box const& b); + + Coord size() const; + T size(size_t dim) const; + + // Sets bound box to the minimum bound box necessary to both have the given + // aspect ratio and contain the current bounding box. + void setAspect(Coord as, bool shrink = false); + + void makeCube(); + + Coord center() const; + void setCenter(Coord const& c); + + void translate(Coord const& c); + Box translated(Coord const& c) const; + + // Translate the Box the minimum distance so that it includes the given point + void translateToInclude(Coord const& coord, Coord const& padding = Coord()); + + Vector<T, 2> range(size_t dim) const; + void setRange(size_t dim, Vector<T, 2> v); + void combineRange(size_t dim, Vector<T, 2> v); + void limitRange(size_t dim, Vector<T, 2> v); + + // Expand from center. + void expand(T factor); + Box expanded(T factor) const; + + // Expand from center. + void expand(Coord const& factor); + Box expanded(Vector<T, N> const& factor) const; + + // Scale around origin. + void scale(T factor); + Box scaled(T factor) const; + + // Scale around origin. + void scale(Coord const& factor); + Box scaled(Vector<T, N> const& factor) const; + + // Increase all dimensions by a constant amount on all sides + void pad(T amount); + Box padded(T amount) const; + + // Increase all dimensions by a constant amount + void pad(Coord const& amount); + Box padded(Vector<T, N> const& amount) const; + + // Opposite of pad + void trim(T amount); + Box trimmed(T amount) const; + + // Opposite of pad + void trim(Coord const& amount); + Box trimmed(Vector<T, N> const& amount) const; + + // Flip around some dimension (may make box have negative volume) + void flip(size_t dimension); + Box flipped(size_t dimension) const; + + Coord const& min() const; + Coord const& max() const; + + Coord& min(); + Coord& max(); + + void setMin(Coord const& c); + void setMax(Coord const& c); + + T volume() const; + Box overlap(Box const& b) const; + + IntersectResult intersection(Box const& b) const; + bool intersects(Box const& b, bool includeEdges = true) const; + + bool contains(Coord const& p, bool includeEdges = true) const; + bool contains(Box const& b, bool includeEdges = true) const; + + // A version of contains that includes the min edges but not the max edges, + // useful to select based on adjoining boxes without overlap. + bool belongs(Coord const& p) const; + + bool containsEpsilon(Coord const& p, unsigned epsilons = 2) const; + bool containsEpsilon(Box const& b, unsigned epsilons = 2) const; + + bool operator==(Box const& ref) const; + bool operator!=(Box const& ref) const; + + // Find Coord inside box nearest to + Coord nearestCoordTo(Coord const& c) const; + + // Find the coord in normalized space for this rect, so that 0 is the minimum + // and 1 is the maximum. + Coord normal(Coord const& coord) const; + + // The invers of normal, find the real space position of this normalized + // coordinate. + Coord eval(Coord const& normalizedCoord) const; + + // 2D Only + + // Slightly different to make ctor work + template <size_t P = N, class = Enable2D<P>> + Box(T minx, T miny, T maxx, T maxy); + + template <size_t P = N> + Enable2D<P, T> xMin() const; + template <size_t P = N> + Enable2D<P, T> xMax() const; + template <size_t P = N> + Enable2D<P, T> yMin() const; + template <size_t P = N> + Enable2D<P, T> yMax() const; + + template <size_t P = N> + Enable2D<P> setXMin(T xMin); + template <size_t P = N> + Enable2D<P> setXMax(T xMax); + template <size_t P = N> + Enable2D<P> setYMin(T yMin); + template <size_t P = N> + Enable2D<P> setYMax(T yMax); + + template <size_t P = N> + Enable2D<P, T> width() const; + template <size_t P = N> + Enable2D<P, T> height() const; + + template <size_t P = N> + Enable2D<P, void> translate(T x, T y); + template <size_t P = N> + Enable2D<P, void> translateToInclude(T x, T y, T xPadding = 0, T yPadding = 0); + template <size_t P = N> + Enable2D<P, void> scale(T x, T y); + template <size_t P = N> + Enable2D<P, void> expand(T x, T y); + template <size_t P = N> + Enable2D<P, void> flipHorizontal(); + template <size_t P = N> + Enable2D<P, void> flipVertical(); + + template <size_t P = N> + Enable2D<P, Array<Line, 4>> edges() const; + template <size_t P = N> + Enable2D<P, bool> intersects(Line const& l) const; + template <size_t P = N> + Enable2D<P, bool> intersectsCircle(Coord const& position, T radius) const; + template <size_t P = N> + Enable2D<P, LineIntersectResult> edgeIntersection(Line const& l) const; + + // Returns a list of areas that are in this rect but not in the given rect. + // Extra Credit: Implement this method for arbitrary dimensions. + template <size_t P = N> + Enable2D<P, List<Box>> subtract(Box const& rect) const; + +protected: + template <typename... TN> + static void combineThings(Box& b, Coord const& point, TN const&... rest); + + template <typename... TN> + static void combineThings(Box& b, Box const& box, TN const&... rest); + + template <typename... TN> + static void combineThings(Box& b); + + Coord m_min; + Coord m_max; +}; + +template <typename T, size_t N> +std::ostream& operator<<(std::ostream& os, Box<T, N> const& box); + +template<typename T> +using Rect = Box<T, 2>; + +typedef Rect<int> RectI; +typedef Rect<unsigned> RectU; +typedef Rect<float> RectF; +typedef Rect<double> RectD; + +template <typename T, size_t N> +Box<T, N> Box<T, N>::null() { + return Box(Coord::filled(std::numeric_limits<T>::max()), Coord::filled(std::numeric_limits<T>::lowest())); +} + +template <typename T, size_t N> +Box<T, N> Box<T, N>::inf() { + return Box(Coord::filled(std::numeric_limits<T>::lowest()), Coord::filled(std::numeric_limits<T>::max())); +} + +template <typename T, size_t N> +template <typename Box2> +Box<T, N> Box<T, N>::integral(Box2 const& box) { + return Box(Coord::floor(box.min()), Coord::ceil(box.max())); +} + +template <typename T, size_t N> +template <typename Box2> +Box<T, N> Box<T, N>::round(Box2 const& box) { + return Box(Coord::round(box.min()), Coord::round(box.max())); +} + +template <typename T, size_t N> +template <typename... TN> +Box<T, N> Box<T, N>::boundBoxOf(TN const&... list) { + Box b = null(); + combineThings(b, list...); + return b; +} + +template <typename T, size_t N> +template <typename Collection> +Box<T, N> Box<T, N>::boundBoxOfPoints(Collection const& collection) { + Box b = null(); + for (auto const& point : collection) + b.combine(Coord(point)); + return b; +} + +template <typename T, size_t N> +Box<T, N> Box<T, N>::withSize(Coord const& min, Coord const& size) { + return Box(min, min + size); +} + +template <typename T, size_t N> +Box<T, N> Box<T, N>::withCenter(Coord const& center, Coord const& size) { + return Box(center - size / 2, center + size / 2); +} + +template <typename T, size_t N> +Box<T, N>::Box() {} + +template <typename T, size_t N> +Box<T, N>::Box(Coord const& min, Coord const& max) + : m_min(min), m_max(max) {} + +template <typename T, size_t N> +Box<T, N>::Box(Box const& b) + : m_min(b.min()), m_max(b.max()) {} + +template <typename T, size_t N> +template <typename T2> +Box<T, N>::Box(Box<T2, N> const& b) + : m_min(b.min()), m_max(b.max()) {} + +template <typename T, size_t N> +bool Box<T, N>::isNull() const { + return m_min == Coord::filled(std::numeric_limits<T>::max()) + && m_max == Coord::filled(std::numeric_limits<T>::lowest()); +} + +template <typename T, size_t N> +bool Box<T, N>::isNegative() const { + for (size_t i = 0; i < N; ++i) { + if (m_max[i] < m_min[i]) + return true; + } + return false; +} + +template <typename T, size_t N> +bool Box<T, N>::isEmpty() const { + for (size_t i = 0; i < N; ++i) { + if (m_max[i] <= m_min[i]) + return true; + } + return false; +} + +template <typename T, size_t N> +void Box<T, N>::combine(Box const& box) { + m_min = box.m_min.piecewiseMin(m_min); + m_max = box.m_max.piecewiseMax(m_max); +} + +template <typename T, size_t N> +Box<T, N> Box<T, N>::combined(Box const& box) const { + auto b = *this; + b.combine(box); + return b; +} + +template <typename T, size_t N> +void Box<T, N>::combine(Coord const& point) { + m_min = m_min.piecewiseMin(point); + m_max = m_max.piecewiseMax(point); +} + +template <typename T, size_t N> +Box<T, N> Box<T, N>::combined(Coord const& point) const { + auto b = *this; + b.combine(point); + return b; +} + +template <typename T, size_t N> +void Box<T, N>::limit(Box const& box) { + m_min = m_min.piecewiseMax(box.m_min); + m_max = m_max.piecewiseMin(box.m_max); +} + +template <typename T, size_t N> +Box<T, N> Box<T, N>::limited(Box const& box) { + auto b = *this; + b.limit(box); + return b; +} + +template <typename T, size_t N> +void Box<T, N>::makePositive() { + for (size_t i = 0; i < N; ++i) { + if (m_max[i] < m_min[i]) + std::swap(m_max[i], m_min[i]); + } +} + +template <typename T, size_t N> +void Box<T, N>::rangeSetIfEmpty(Box const& b) { + for (size_t i = 0; i < N; ++i) { + if (m_max[i] <= m_min[i]) + setRange(i, b.range(i)); + } +} + +template <typename T, size_t N> +void Box<T, N>::makeCube() { + setAspect(Coord::filled(1)); +} + +template <typename T, size_t N> +auto Box<T, N>::size() const -> Coord { + return m_max - m_min; +} + +template <typename T, size_t N> +T Box<T, N>::size(size_t dim) const { + return m_max[dim] - m_min[dim]; +} + +template <typename T, size_t N> +void Box<T, N>::setAspect(Coord as, bool shrink) { + Coord nBox = (m_max - m_min).piecewiseDivide(as); + Coord extent; + if (shrink) + extent = Coord::filled(nBox.min()); + else + extent = Coord::filled(nBox.max()); + extent = extent.piecewiseMult(as); + Coord center = (m_max + m_min) / 2; + m_max = center + extent / 2; + m_min = center - extent / 2; +} + +template <typename T, size_t N> +auto Box<T, N>::center() const -> Coord { + return (m_min + m_max) / 2; +} + +template <typename T, size_t N> +void Box<T, N>::setCenter(Coord const& c) { + translate(c - center()); +} + +template <typename T, size_t N> +void Box<T, N>::translate(Coord const& c) { + m_min += c; + m_max += c; +} + +template <typename T, size_t N> +Box<T, N> Box<T, N>::translated(Coord const& c) const { + auto b = *this; + b.translate(c); + return b; +} + +template <typename T, size_t N> +void Box<T, N>::translateToInclude(Coord const& coord, Coord const& padding) { + Coord translation; + for (size_t i = 0; i < N; ++i) { + if (coord[i] < m_min[i] + padding[i]) + translation[i] = coord[i] - m_min[i] - padding[i]; + else if (coord[i] > m_max[i] - padding[i]) + translation[i] = coord[i] - m_max[i] + padding[i]; + } + translate(translation); +} + +template <typename T, size_t N> +Vector<T, 2> Box<T, N>::range(size_t dim) const { + return Coord(m_min[dim], m_max[dim]); +} + +template <typename T, size_t N> +void Box<T, N>::setRange(size_t dim, Vector<T, 2> v) { + m_min[dim] = v[0]; + m_max[dim] = v[1]; +} + +template <typename T, size_t N> +void Box<T, N>::combineRange(size_t dim, Vector<T, 2> v) { + m_min[dim] = std::min(m_min[dim], v[0]); + m_max[dim] = std::max(m_max[dim], v[1]); +} + +template <typename T, size_t N> +void Box<T, N>::limitRange(size_t dim, Vector<T, 2> v) { + m_min[dim] = std::max(m_min[dim], v[0]); + m_max[dim] = std::min(m_max[dim], v[1]); +} + +template <typename T, size_t N> +void Box<T, N>::expand(T factor) { + for (size_t i = 0; i < N; ++i) { + auto rng = range(i); + T center = rng.sum() / 2; + T newDist = (rng[1] - rng[0]) * factor; + setRange(i, Coord(center - newDist / 2, center + newDist / 2)); + } +} + +template <typename T, size_t N> +Box<T, N> Box<T, N>::expanded(T factor) const { + auto b = *this; + b.expand(factor); + return b; +} + +template <typename T, size_t N> +void Box<T, N>::expand(Coord const& factor) { + for (size_t i = 0; i < N; ++i) { + auto rng = range(i); + T center = rng.sum() / 2; + T newDist = (rng[1] - rng[0]) * factor[i]; + setRange(i, Coord(center - newDist / 2, center + newDist / 2)); + } +} + +template <typename T, size_t N> +Box<T, N> Box<T, N>::expanded(Coord const& factor) const { + auto b = *this; + b.expand(factor); + return b; +} + +template <typename T, size_t N> +void Box<T, N>::scale(T factor) { + for (size_t i = 0; i < N; ++i) + setRange(i, range(i) * factor); +} + +template <typename T, size_t N> +Box<T, N> Box<T, N>::scaled(T factor) const { + auto b = *this; + b.scale(factor); + return b; +} + +template <typename T, size_t N> +void Box<T, N>::scale(Coord const& factor) { + for (size_t i = 0; i < N; ++i) + setRange(i, range(i) * factor[i]); +} + +template <typename T, size_t N> +Box<T, N> Box<T, N>::scaled(Coord const& factor) const { + auto b = *this; + b.scale(factor); + return b; +} + +template <typename T, size_t N> +void Box<T, N>::pad(T amount) { + for (size_t i = 0; i < N; ++i) { + m_min[i] -= amount; + m_max[i] += amount; + } +} + +template <typename T, size_t N> +Box<T, N> Box<T, N>::padded(T amount) const { + auto b = *this; + b.pad(amount); + return b; +} + +template <typename T, size_t N> +void Box<T, N>::pad(Coord const& amount) { + for (size_t i = 0; i < N; ++i) { + m_min[i] -= amount[i]; + m_max[i] += amount[i]; + } +} + +template <typename T, size_t N> +Box<T, N> Box<T, N>::padded(Coord const& amount) const { + auto b = *this; + b.pad(amount); + return b; +} + +template <typename T, size_t N> +void Box<T, N>::trim(T amount) { + pad(-amount); +} + +template <typename T, size_t N> +Box<T, N> Box<T, N>::trimmed(T amount) const { + auto b = *this; + b.trim(amount); + return b; +} + +template <typename T, size_t N> +void Box<T, N>::trim(Coord const& amount) { + pad(-amount); +} + +template <typename T, size_t N> +Box<T, N> Box<T, N>::trimmed(Coord const& amount) const { + auto b = *this; + b.trim(amount); + return b; +} + +template <typename T, size_t N> +void Box<T, N>::flip(size_t dimension) { + std::swap(m_min[dimension], m_max[dimension]); +} + +template <typename T, size_t N> +Box<T, N> Box<T, N>::flipped(size_t dimension) const { + auto b = *this; + b.flip(dimension); + return b; +} + +template <typename T, size_t N> +auto Box<T, N>::normal(Coord const& coord) const -> Coord { + return (coord - m_min).piecewiseDivide(m_max - m_min); +} + +template <typename T, size_t N> +auto Box<T, N>::eval(Coord const& normalizedCoord) const -> Coord { + return normalizedCoord.piecewiseMultiply(m_max - m_min) + m_min; +} + +template <typename T, size_t N> +auto Box<T, N>::min() const -> Coord const & { + return m_min; +} + +template <typename T, size_t N> +auto Box<T, N>::max() const -> Coord const & { + return m_max; +} + +template <typename T, size_t N> +auto Box<T, N>::min() -> Coord & { + return m_min; +} + +template <typename T, size_t N> +auto Box<T, N>::max() -> Coord & { + return m_max; +} + +template <typename T, size_t N> +void Box<T, N>::setMin(Coord const& c) { + m_min = c; +} + +template <typename T, size_t N> +void Box<T, N>::setMax(Coord const& c) { + m_max = c; +} + +template <typename T, size_t N> +T Box<T, N>::volume() const { + return size().product(); +} + +template <typename T, size_t N> +auto Box<T, N>::overlap(Box const& b) const -> Box { + Box result = *this; + for (size_t i = 0; i < N; ++i) { + result.m_min[i] = std::max(result.m_min[i], b.m_min[i]); + result.m_max[i] = std::min(result.m_max[i], b.m_max[i]); + } + return result; +} + +template <typename T, size_t N> +auto Box<T, N>::intersection(Box const& b) const -> IntersectResult { + IntersectResult res; + + T overlap = std::numeric_limits<T>::max(); + size_t dim = 0; + bool negative = false; + for (size_t i = 0; i < N; ++i) { + if (m_max[i] - b.m_min[i] < overlap) { + overlap = m_max[i] - b.m_min[i]; + dim = i; + negative = true; + } + if (b.m_max[i] - m_min[i] < overlap) { + overlap = b.m_max[i] - m_min[i]; + dim = i; + negative = false; + } + } + + res.overlap = Coord(); + if (overlap > 0) { + res.intersects = true; + res.overlap[dim] = overlap; + } else { + res.intersects = false; + res.overlap[dim] = -overlap; + } + + if (negative) + res.overlap = -res.overlap; + + if (res.overlap == Coord()) { + res.glances = true; + } else { + res.glances = false; + } + + return res; +} + +template <typename T, size_t N> +bool Box<T, N>::intersects(Box const& b, bool includeEdges) const { + for (size_t i = 0; i < N; ++i) { + if (includeEdges) { + if (m_max[i] < b.m_min[i] || b.m_max[i] < m_min[i]) + return false; + } else { + if (m_max[i] <= b.m_min[i] || b.m_max[i] <= m_min[i]) + return false; + } + } + return true; +} + +template <typename T, size_t N> +bool Box<T, N>::contains(Coord const& p, bool includeEdges) const { + for (size_t i = 0; i < N; ++i) { + if (includeEdges) { + if (p[i] < m_min[i] || p[i] > m_max[i]) + return false; + } else { + if (p[i] <= m_min[i] || p[i] >= m_max[i]) + return false; + } + } + return true; +} + +template <typename T, size_t N> +bool Box<T, N>::contains(Box const& b, bool includeEdges) const { + return contains(b.min(), includeEdges) && contains(b.max(), includeEdges); +} + +template <typename T, size_t N> +bool Box<T, N>::belongs(Coord const& p) const { + for (size_t i = 0; i < N; ++i) { + if (p[i] < m_min[i] || p[i] >= m_max[i]) + return false; + } + + return true; +} + +template <typename T, size_t N> +bool Box<T, N>::containsEpsilon(Coord const& p, unsigned epsilons) const { + for (size_t i = 0; i < N; ++i) { + if (p[i] < m_min[i] || p[i] > m_max[i]) + return false; + if (nearEqual(p[i], m_min[i], epsilons) || nearEqual(p[i], m_max[i], epsilons)) + return false; + } + return true; +} + +template <typename T, size_t N> +bool Box<T, N>::containsEpsilon(Box const& b, unsigned epsilons) const { + return containsEpsilon(b.min(), epsilons) && containsEpsilon(b.max(), epsilons); +} + +template <typename T, size_t N> +bool Box<T, N>::operator==(Box const& ref) const { + return m_min == ref.m_min && m_max == ref.m_max; +} + +template <typename T, size_t N> +bool Box<T, N>::operator!=(Box const& ref) const { + return m_min != ref.m_min || m_max != ref.m_max; +} + +template <typename T, size_t N> +template <typename... TN> +void Box<T, N>::combineThings(Box& b, Coord const& point, TN const&... rest) { + b.combine(point); + combineThings(b, rest...); +} + +template <typename T, size_t N> +template <typename... TN> +void Box<T, N>::combineThings(Box& b, Box const& box, TN const&... rest) { + b.combine(box); + combineThings(b, rest...); +} + +template <typename T, size_t N> +template <typename... TN> +void Box<T, N>::combineThings(Box&) {} + +template <typename T, size_t N> +std::ostream& operator<<(std::ostream& os, Box<T, N> const& box) { + os << "Box{min:" << box.min() << " max:" << box.max() << "}"; + return os; +} + +template <typename T, size_t N> +template <size_t P, class> +Box<T, N>::Box(T minx, T miny, T maxx, T maxy) + : Box(Coord(minx, miny), Coord(maxx, maxy)) {} + +template <typename T, size_t N> +template <size_t P> +auto Box<T, N>::xMin() const -> Enable2D<P, T> { + return min()[0]; +} + +template <typename T, size_t N> +template <size_t P> +auto Box<T, N>::xMax() const -> Enable2D<P, T> { + return max()[0]; +} + +template <typename T, size_t N> +template <size_t P> +auto Box<T, N>::yMin() const -> Enable2D<P, T> { + return min()[1]; +} + +template <typename T, size_t N> +template <size_t P> +auto Box<T, N>::yMax() const -> Enable2D<P, T> { + return max()[1]; +} + +template <typename T, size_t N> +template <size_t P> +auto Box<T, N>::setXMin(T xMin) -> Enable2D<P> { + m_min[0] = xMin; +} + +template <typename T, size_t N> +template <size_t P> +auto Box<T, N>::setXMax(T xMax) -> Enable2D<P> { + m_max[0] = xMax; +} + +template <typename T, size_t N> +template <size_t P> +auto Box<T, N>::setYMin(T yMin) -> Enable2D<P> { + m_min[1] = yMin; +} + +template <typename T, size_t N> +template <size_t P> +auto Box<T, N>::setYMax(T yMax) -> Enable2D<P> { + m_max[1] = yMax; +} + +template <typename T, size_t N> +template <size_t P> +auto Box<T, N>::width() const -> Enable2D<P, T> { + return size(0); +} + +template <typename T, size_t N> +template <size_t P> +auto Box<T, N>::height() const -> Enable2D<P, T> { + return size(1); +} + +template <typename T, size_t N> +template <size_t P> +auto Box<T, N>::translate(T x, T y) -> Enable2D<P, void> { + translate(Coord(x, y)); +} + +template <typename T, size_t N> +template <size_t P> +auto Box<T, N>::translateToInclude(T x, T y, T xPadding, T yPadding) -> Enable2D<P, void> { + translateToInclude(Coord(x, y), Coord(xPadding, yPadding)); +} + +template <typename T, size_t N> +template <size_t P> +auto Box<T, N>::scale(T x, T y) -> Enable2D<P, void> { + scale(Coord(x, y)); +} + +template <typename T, size_t N> +template <size_t P> +auto Box<T, N>::expand(T x, T y) -> Enable2D<P, void> { + expand(Coord(x, y)); +} + +template <typename T, size_t N> +template <size_t P> +auto Box<T, N>::flipHorizontal() -> Enable2D<P, void> { + flip(0); +} + +template <typename T, size_t N> +template <size_t P> +auto Box<T, N>::flipVertical() -> Enable2D<P, void> { + flip(1); +} + +template <typename T, size_t N> +template <size_t P> +auto Box<T, N>::edges() const -> Enable2D<P, Array<Line, 4>> { + Array<Line, 4> res; + res[0] = {min(), {min()[0], max()[1]}}; + res[1] = {min(), {max()[0], min()[1]}}; + res[2] = {{min()[0], max()[1]}, max()}; + res[3] = {{max()[0], min()[1]}, max()}; + return res; +} + +template <typename T, size_t N> +template <size_t P> +auto Box<T, N>::intersects(Line const& l) const -> Enable2D<P, bool> { + if (contains(l.min()) || contains(l.max())) + return true; + + for (auto i : edges()) { + if (l.intersects(i)) + return true; + } + return false; +} + +template <typename T, size_t N> +template <size_t P> +auto Box<T, N>::intersectsCircle(Coord const& position, T radius) const -> Enable2D<P, bool> { + if (contains(position)) + return true; + for (auto const& e : edges()) { + if (e.distanceTo(position) <= radius) + return true; + } + return false; +} + +// returns the closest intersection point (from l.min()) +template <typename T, size_t N> +template <size_t P> +auto Box<T, N>::edgeIntersection(Line const& l) const -> Enable2D<P, LineIntersectResult> { + Array<LineIntersectResult, 4> candidates; + size_t numCandidates = 0; + + for (auto i : edges()) { + auto res = l.intersection(i); + if (res.intersects) + candidates[numCandidates++] = res; + } + + // How glancing is determined + // There are a few possibilities + // if candidates is empty then no intersection, easy + // if there is only one candidate then there are two possibilities, glancing + // or not + // But! if an endpoint is inside the rect, not just on the edge then it's + // false + // if there are two candidates and at least one of them is not glancing then + // false + // if there are two candidates and at they're both glancing then there's a few + // possibilities + // first, the line cuts through the corner, we can detect this by seeing if + // the point is in the + // box but not on the edge + // second, the line cuts across the corner, this case is true + // third, the line coincides with one of the sides, this case is also true. + // if there are 3 candidates then two cases + // first, the line coincides with one of the sides, and glances off of the + // other two, true + // second, the line cuts through a corner and reaches the far side, false + // we can tell these apart by determining if any intersections coincide + // if there are 4 candidates then the only possible case is false (cutting + // through both corners + if (numCandidates != 0) { + std::sort(candidates.ptr(), + candidates.ptr() + numCandidates, + [&](LineIntersectResult const& a, LineIntersectResult const& b) { return a.t < b.t; }); + if (numCandidates == 1) { + if (contains(l.min(), false) || contains(l.max(), false)) { + candidates[0].glances = false; + } + } else if (numCandidates == 2) { + if (contains(l.min(), false) || contains(l.max(), false)) { + candidates[0].glances = false; + } else if (contains(l.min()) && !candidates[1].glances) { + candidates[0].glances = false; + } + if (candidates[1].coincides) { // If we coincide on either consider it true + candidates[0].coincides = true; + } + } else if (numCandidates == 3) { + if (candidates[0].coincides || candidates[1].coincides || candidates[2].coincides) { + candidates[0].glances = true; + candidates[0].coincides = true; + } else { + candidates[0].glances = false; + } + } else { + candidates[0].glances = false; + candidates[0].coincides = false; + } + + return candidates[0]; + } else { + return LineIntersectResult(); + } +} + +template <typename T, size_t N> +template <size_t P> +auto Box<T, N>::subtract(Box const& rect) const -> Enable2D<P, List<Box>> { + List<Box> regions; + + auto overlap = Box::overlap(rect); + if (overlap.isEmpty()) { + // If this rect doesn't overlap at all with the subtracted one, obviously + // the entire rect is new territory. + regions.append(*this); + } else { + // If there is overlap with this rect, we need to add the left, bottom, + // right, and top sections. These can overlap at the corners, so the left + // and right sections will take the lower / upper left and lower / upper + // right corners, and the top and bottom will be limited to the width of + // the overlap section. + + if (xMin() < overlap.xMin()) + regions.append(Box(xMin(), yMin(), overlap.xMin(), yMax())); + + if (overlap.xMax() < xMax()) + regions.append(Box(overlap.xMax(), yMin(), xMax(), yMax())); + + if (yMin() < overlap.yMin()) + regions.append(Box(rect.xMin(), yMin(), rect.xMax(), overlap.yMin())); + + if (overlap.yMax() < yMax()) + regions.append(Box(rect.xMin(), overlap.yMax(), rect.xMax(), yMax())); + } + + return regions; +} + +template <typename T, size_t N> +auto Box<T, N>::nearestCoordTo(Coord const& c) const -> Coord { + Coord result = c; + for (size_t i = 0; i < N; ++i) + result[i] = clamp(result[i], m_min[i], m_max[i]); + return result; +} + +} + +#endif diff --git a/source/core/StarRefPtr.hpp b/source/core/StarRefPtr.hpp new file mode 100644 index 0000000..906d4ff --- /dev/null +++ b/source/core/StarRefPtr.hpp @@ -0,0 +1,303 @@ +#ifndef STAR_REF_PTR_HPP +#define STAR_REF_PTR_HPP + +#include "StarException.hpp" +#include "StarHash.hpp" + +namespace Star { + +// Reference counted ptr for intrusive reference counted types. Calls +// unqualified refPtrIncRef and refPtrDecRef functions to manage the reference +// count. +template <typename T> +class RefPtr { +public: + typedef T element_type; + + RefPtr(); + explicit RefPtr(T* p, bool addRef = true); + + RefPtr(RefPtr const& r); + RefPtr(RefPtr&& r); + + template <typename T2> + RefPtr(RefPtr<T2> const& r); + template <typename T2> + RefPtr(RefPtr<T2>&& r); + + ~RefPtr(); + + RefPtr& operator=(RefPtr const& r); + RefPtr& operator=(RefPtr&& r); + + template <typename T2> + RefPtr& operator=(RefPtr<T2> const& r); + template <typename T2> + RefPtr& operator=(RefPtr<T2>&& r); + + void reset(); + + void reset(T* r, bool addRef = true); + + T& operator*() const; + T* operator->() const; + T* get() const; + + explicit operator bool() const; + +private: + template <typename T2> + friend class RefPtr; + + T* m_ptr; +}; + +template <typename T, typename U> +bool operator==(RefPtr<T> const& a, RefPtr<U> const& b); + +template <typename T, typename U> +bool operator!=(RefPtr<T> const& a, RefPtr<U> const& b); + +template <typename T> +bool operator==(RefPtr<T> const& a, T* b); + +template <typename T> +bool operator!=(RefPtr<T> const& a, T* b); + +template <typename T> +bool operator==(T* a, RefPtr<T> const& b); + +template <typename T> +bool operator!=(T* a, RefPtr<T> const& b); + +template <typename T, typename U> +bool operator<(RefPtr<T> const& a, RefPtr<U> const& b); + +template <typename Type1, typename Type2> +bool is(RefPtr<Type2> const& p); + +template <typename Type1, typename Type2> +bool is(RefPtr<Type2 const> const& p); + +template <typename Type1, typename Type2> +RefPtr<Type1> as(RefPtr<Type2> const& p); + +template <typename Type1, typename Type2> +RefPtr<Type1 const> as(RefPtr<Type2 const> const& p); + +template <typename T, typename... Args> +RefPtr<T> make_ref(Args&&... args); + +template <typename T> +struct hash<RefPtr<T>> { + size_t operator()(RefPtr<T> const& a) const; + + hash<T*> hasher; +}; + +// Base class for RefPtr that is NOT thread safe. This can have a performance +// benefit over shared_ptr in single threaded contexts. +class RefCounter { +public: + friend void refPtrIncRef(RefCounter* p); + friend void refPtrDecRef(RefCounter* p); + +protected: + RefCounter(); + virtual ~RefCounter() = default; + +private: + size_t m_refCounter; +}; + +template <typename T> +RefPtr<T>::RefPtr() + : m_ptr(nullptr) {} + +template <typename T> +RefPtr<T>::RefPtr(T* p, bool addRef) + : m_ptr(nullptr) { + reset(p, addRef); +} + +template <typename T> +RefPtr<T>::RefPtr(RefPtr const& r) + : RefPtr(r.m_ptr) {} + +template <typename T> +RefPtr<T>::RefPtr(RefPtr&& r) { + m_ptr = r.m_ptr; + r.m_ptr = nullptr; +} + +template <typename T> +template <typename T2> +RefPtr<T>::RefPtr(RefPtr<T2> const& r) + : RefPtr(r.m_ptr) {} + +template <typename T> +template <typename T2> +RefPtr<T>::RefPtr(RefPtr<T2>&& r) { + m_ptr = r.m_ptr; + r.m_ptr = nullptr; +} + +template <typename T> +RefPtr<T>::~RefPtr() { + if (m_ptr) + refPtrDecRef(m_ptr); +} + +template <typename T> +RefPtr<T>& RefPtr<T>::operator=(RefPtr const& r) { + reset(r.m_ptr); + return *this; +} + +template <typename T> +RefPtr<T>& RefPtr<T>::operator=(RefPtr&& r) { + if (m_ptr) + refPtrDecRef(m_ptr); + + m_ptr = r.m_ptr; + r.m_ptr = nullptr; + return *this; +} + +template <typename T> +template <typename T2> +RefPtr<T>& RefPtr<T>::operator=(RefPtr<T2> const& r) { + reset(r.m_ptr); + return *this; +} + +template <typename T> +template <typename T2> +RefPtr<T>& RefPtr<T>::operator=(RefPtr<T2>&& r) { + if (m_ptr) + refPtrDecRef(m_ptr); + + m_ptr = r.m_ptr; + r.m_ptr = nullptr; + return *this; +} + +template <typename T> +void RefPtr<T>::reset() { + reset(nullptr); +} + +template <typename T> +void RefPtr<T>::reset(T* r, bool addRef) { + if (m_ptr == r) + return; + + if (m_ptr) + refPtrDecRef(m_ptr); + + m_ptr = r; + + if (m_ptr && addRef) + refPtrIncRef(m_ptr); +} + +template <typename T> +T& RefPtr<T>::operator*() const { + return *m_ptr; +} + +template <typename T> +T* RefPtr<T>::operator->() const { + return m_ptr; +} + +template <typename T> +T* RefPtr<T>::get() const { + return m_ptr; +} + +template <typename T> +RefPtr<T>::operator bool() const { + return m_ptr != nullptr; +} + +template <typename T, typename U> +bool operator==(RefPtr<T> const& a, RefPtr<U> const& b) { + return a.get() == b.get(); +} + +template <typename T, typename U> +bool operator!=(RefPtr<T> const& a, RefPtr<U> const& b) { + return a.get() != b.get(); +} + +template <typename T> +bool operator==(RefPtr<T> const& a, T* b) { + return a.get() == b; +} + +template <typename T> +bool operator!=(RefPtr<T> const& a, T* b) { + return a.get() != b; +} + +template <typename T> +bool operator==(T* a, RefPtr<T> const& b) { + return a == b.get(); +} + +template <typename T> +bool operator!=(T* a, RefPtr<T> const& b) { + return a != b.get(); +} + +template <typename T, typename U> +bool operator<(RefPtr<T> const& a, RefPtr<U> const& b) { + return a.get() < b.get(); +} + +template <typename Type1, typename Type2> +bool is(RefPtr<Type2> const& p) { + return (bool)dynamic_cast<Type1*>(p.get()); +} + +template <typename Type1, typename Type2> +bool is(RefPtr<Type2 const> const& p) { + return (bool)dynamic_cast<Type1 const*>(p.get()); +} + +template <typename Type1, typename Type2> +RefPtr<Type1> as(RefPtr<Type2> const& p) { + return RefPtr<Type1>(dynamic_cast<Type1*>(p.get())); +} + +template <typename Type1, typename Type2> +RefPtr<Type1 const> as(RefPtr<Type2 const> const& p) { + return RefPtr<Type1>(dynamic_cast<Type1 const*>(p.get())); +} + +template <typename T, typename... Args> +RefPtr<T> make_ref(Args&&... args) { + return RefPtr<T>(new T(forward<Args>(args)...)); +} + +template <typename T> +size_t hash<RefPtr<T>>::operator()(RefPtr<T> const& a) const { + return hasher(a.get()); +} + +inline void refPtrIncRef(RefCounter* p) { + ++p->m_refCounter; +} + +inline void refPtrDecRef(RefCounter* p) { + if (--p->m_refCounter == 0) + delete p; +} + +inline RefCounter::RefCounter() + : m_refCounter(0) {} + +} + +#endif diff --git a/source/core/StarRpcPromise.hpp b/source/core/StarRpcPromise.hpp new file mode 100644 index 0000000..dcfd242 --- /dev/null +++ b/source/core/StarRpcPromise.hpp @@ -0,0 +1,175 @@ +#ifndef STAR_RPC_PROMISE_HPP +#define STAR_RPC_PROMISE_HPP + +#include "StarEither.hpp" +#include "StarString.hpp" + +namespace Star { + +STAR_EXCEPTION(RpcPromiseException, StarException); + +// The other side of an RpcPromise, can be used to either fulfill or fail a +// paired promise. Call either fulfill or fail function exactly once, any +// further invocations will result in an exception. +template <typename Result, typename Error = String> +class RpcPromiseKeeper { +public: + void fulfill(Result result); + void fail(Error error); + +private: + template <typename ResultT, typename ErrorT> + friend class RpcPromise; + + function<void(Result)> m_fulfill; + function<void(Error)> m_fail; +}; + +// A generic promise for the result of a remote procedure call. It has +// reference semantics and is implicitly shared, but is not thread safe. +template <typename Result, typename Error = String> +class RpcPromise { +public: + static pair<RpcPromise, RpcPromiseKeeper<Result, Error>> createPair(); + static RpcPromise createFulfilled(Result result); + static RpcPromise createFailed(Error error); + + // Has the respoonse either failed or succeeded? + bool finished() const; + // Has the response finished with success? + bool succeeded() const; + // Has the response finished with failure? + bool failed() const; + + // Returns the result of the rpc call on success, nothing on failure or when + // not yet finished. + Maybe<Result> const& result() const; + + // Returns the error of a failed rpc call. Returns nothing if the call is + // successful or not yet finished. + Maybe<Error> const& error() const; + + // Wrap this RpcPromise into another promise which returns instead the result + // of this function when fulfilled + template <typename Function> + decltype(auto) wrap(Function function); + +private: + template <typename ResultT, typename ErrorT> + friend class RpcPromise; + + struct Value { + Maybe<Result> result; + Maybe<Error> error; + }; + + RpcPromise() = default; + + function<Value const*()> m_getValue; +}; + +template <typename Result, typename Error> +void RpcPromiseKeeper<Result, Error>::fulfill(Result result) { + m_fulfill(move(result)); +} + +template <typename Result, typename Error> +void RpcPromiseKeeper<Result, Error>::fail(Error error) { + m_fail(move(error)); +} + +template <typename Result, typename Error> +pair<RpcPromise<Result, Error>, RpcPromiseKeeper<Result, Error>> RpcPromise<Result, Error>::createPair() { + auto valuePtr = make_shared<Value>(); + + RpcPromise promise; + promise.m_getValue = [valuePtr]() { + return valuePtr.get(); + }; + + RpcPromiseKeeper<Result, Error> keeper; + keeper.m_fulfill = [valuePtr](Result result) { + if (valuePtr->result || valuePtr->error) + throw RpcPromiseException("fulfill called on already finished RpcPromise"); + valuePtr->result = move(result); + }; + keeper.m_fail = [valuePtr](Error error) { + if (valuePtr->result || valuePtr->error) + throw RpcPromiseException("fail called on already finished RpcPromise"); + valuePtr->error = move(error); + }; + + return {move(promise), move(keeper)}; +} + +template <typename Result, typename Error> +RpcPromise<Result, Error> RpcPromise<Result, Error>::createFulfilled(Result result) { + auto valuePtr = make_shared<Value>(); + valuePtr->result = move(result); + + RpcPromise<Result, Error> promise; + promise.m_getValue = [valuePtr]() { + return valuePtr.get(); + }; + return promise; +} + +template <typename Result, typename Error> +RpcPromise<Result, Error> RpcPromise<Result, Error>::createFailed(Error error) { + auto valuePtr = make_shared<Value>(); + valuePtr->error = move(error); + + RpcPromise<Result, Error> promise; + promise.m_getValue = [valuePtr]() { + return valuePtr.get(); + }; + return promise; +} + +template <typename Result, typename Error> +bool RpcPromise<Result, Error>::finished() const { + auto val = m_getValue(); + return val->result || val->error; +} + +template <typename Result, typename Error> +bool RpcPromise<Result, Error>::succeeded() const { + return m_getValue()->result.isValid(); +} + +template <typename Result, typename Error> +bool RpcPromise<Result, Error>::failed() const { + return m_getValue()->error.isValid(); +} + +template <typename Result, typename Error> +Maybe<Result> const& RpcPromise<Result, Error>::result() const { + return m_getValue()->result; +} + +template <typename Result, typename Error> +Maybe<Error> const& RpcPromise<Result, Error>::error() const { + return m_getValue()->error; +} + +template <typename Result, typename Error> +template <typename Function> +decltype(auto) RpcPromise<Result, Error>::wrap(Function function) { + typedef RpcPromise<typename std::decay<decltype(function(std::declval<Result>()))>::type, Error> WrappedPromise; + WrappedPromise wrappedPromise; + wrappedPromise.m_getValue = [wrapper = move(function), valuePtr = make_shared<typename WrappedPromise::Value>(), otherGetValue = m_getValue]() { + if (!valuePtr->result && !valuePtr->error) { + auto otherValue = otherGetValue(); + if (otherValue->result) + valuePtr->result.set(wrapper(*otherValue->result)); + else if (otherValue->error) + valuePtr->error.set(*otherValue->error); + } + return valuePtr.get(); + }; + return wrappedPromise; +} + +} + +#endif diff --git a/source/core/StarSectorArray2D.hpp b/source/core/StarSectorArray2D.hpp new file mode 100644 index 0000000..4645804 --- /dev/null +++ b/source/core/StarSectorArray2D.hpp @@ -0,0 +1,378 @@ +#ifndef STAR_SECTOR_SET_HPP +#define STAR_SECTOR_SET_HPP + +#include "StarMultiArray.hpp" +#include "StarSet.hpp" +#include "StarVector.hpp" + +namespace Star { + +// Holds a sparse 2d array of data based on sector size. Meant to be used as a +// fast-as-possible sparse array. Memory requiremenets are equal to the size +// of all loaded sectors PLUS pointer size * sectors wide * sectors high +template <typename ElementT, size_t SectorSize> +class SectorArray2D { +public: + typedef ElementT Element; + typedef Vec2S Sector; + + struct SectorRange { + // Lower left sector + Vec2S min; + // Upper right sector *non-inclusive* + Vec2S max; + }; + + struct Array { + Array(); + Array(Element const& def); + + Element const& operator()(size_t x, size_t y) const; + Element& operator()(size_t x, size_t y); + + Element elements[SectorSize * SectorSize]; + }; + typedef unique_ptr<Array> ArrayPtr; + + typedef MultiArray<Element, 2> DynamicArray; + + SectorArray2D(); + SectorArray2D(size_t numSectorsWide, size_t numSectorsHigh); + + void init(size_t numSectorsWide, size_t numSectorsHigh); + + // Total size of array elements + size_t width() const; + size_t height() const; + + // Is sector within width() and heigh() + bool sectorValid(Sector const& sector) const; + + // Returns the sector that contains the given point + Sector sectorFor(size_t x, size_t y) const; + // Returns the sector range that contains the given rectangle + SectorRange sectorRange(size_t minX, size_t minY, size_t width, size_t height) const; + + Vec2S sectorCorner(Sector const& id) const; + bool hasSector(Sector const& id) const; + + List<Sector> loadedSectors() const; + size_t loadedSectorCount() const; + bool sectorLoaded(Sector const& id) const; + + // Will return nullptr if sector is not loaded. + Array* sector(Sector const& id); + Array const* sector(Sector const& id) const; + + void loadSector(Sector const& id, ArrayPtr array); + ArrayPtr copySector(Sector const& id); + ArrayPtr takeSector(Sector const& id); + void discardSector(Sector const& id); + + // Will return nullptr if sector is not loaded. + Element const* get(size_t x, size_t y) const; + Element* get(size_t x, size_t y); + + // Fast evaluate of elements in the given range. If evalEmpty is true, then + // function will be called even for unloaded sectors (with null pointer). + // Function is called as function(size_t x, size_t y, Element* element). + // Given function should return true to continue, false to stop. Returns + // false if any evaled functions return false. + template <typename Function> + bool eval(size_t minX, size_t minY, size_t width, size_t height, Function&& function, bool evalEmpty = false) const; + template <typename Function> + bool eval(size_t minX, size_t minY, size_t width, size_t height, Function&& function, bool evalEmpty = false); + + // Individual sectors are stored column-major, so for speed, use this method + // to get whole columns at a time. If eval empty is true, function will be + // called with for each empty column with the correct size information, but + // the pointer will be null. Function will be called as + // function(size_t x, size_t y, Element* columnElements, size_t columnSize) + // columnSize is guaranteed never to be greater than SectorSize. Given + // function should return true to continue, false to stop. Returns false if + // any evaled columns return false. + template <typename Function> + bool evalColumns( + size_t minX, size_t minY, size_t width, size_t height, Function&& function, bool evalEmpty = false) const; + template <typename Function> + bool evalColumns(size_t minX, size_t minY, size_t width, size_t height, Function&& function, bool evalEmpty = false); + +private: + typedef MultiArray<ArrayPtr, 2> SectorArray; + + template <typename Function> + bool evalPriv(size_t minX, size_t minY, size_t width, size_t height, Function&& function, bool evalEmpty); + template <typename Function> + bool evalColumnsPriv(size_t minX, size_t minY, size_t width, size_t height, Function&& function, bool evalEmpty); + + SectorArray m_sectors; + HashSet<Sector> m_loadedSectors; +}; + +template <typename ElementT, size_t SectorSize> +SectorArray2D<ElementT, SectorSize>::Array::Array() + : elements() {} + +template <typename ElementT, size_t SectorSize> +SectorArray2D<ElementT, SectorSize>::Array::Array(Element const& def) { + for (size_t i = 0; i < SectorSize * SectorSize; ++i) + elements[i] = def; +} + +template <typename ElementT, size_t SectorSize> +ElementT const& SectorArray2D<ElementT, SectorSize>::Array::operator()(size_t x, size_t y) const { + starAssert(x < SectorSize && y < SectorSize); + return elements[x * SectorSize + y]; +} + +template <typename ElementT, size_t SectorSize> +ElementT& SectorArray2D<ElementT, SectorSize>::Array::operator()(size_t x, size_t y) { + starAssert(x < SectorSize && y < SectorSize); + return elements[x * SectorSize + y]; +} + +template <typename ElementT, size_t SectorSize> +SectorArray2D<ElementT, SectorSize>::SectorArray2D() {} + +template <typename ElementT, size_t SectorSize> +SectorArray2D<ElementT, SectorSize>::SectorArray2D(size_t numSectorsWide, size_t numSectorsHigh) { + init(numSectorsWide, numSectorsHigh); +} + +template <typename ElementT, size_t SectorSize> +void SectorArray2D<ElementT, SectorSize>::init(size_t numSectorsWide, size_t numSectorsHigh) { + m_sectors.clear(); + m_sectors.setSize(numSectorsWide, numSectorsHigh); + m_loadedSectors.clear(); +} + +template <typename ElementT, size_t SectorSize> +size_t SectorArray2D<ElementT, SectorSize>::width() const { + return m_sectors.size(0) * SectorSize; +} + +template <typename ElementT, size_t SectorSize> +size_t SectorArray2D<ElementT, SectorSize>::height() const { + return m_sectors.size(1) * SectorSize; +} + +template <typename ElementT, size_t SectorSize> +bool SectorArray2D<ElementT, SectorSize>::sectorValid(Sector const& sector) const { + return sector[0] < m_sectors.size(0) && sector[1] < m_sectors.size(1); +} + +template <typename ElementT, size_t SectorSize> +auto SectorArray2D<ElementT, SectorSize>::sectorFor(size_t x, size_t y) const -> Sector { + return {x / SectorSize, y / SectorSize}; +} + +template <typename ElementT, size_t SectorSize> +auto SectorArray2D<ElementT, SectorSize>::sectorRange(size_t minX, size_t minY, size_t width, size_t height) const -> SectorRange { + return { + {minX / SectorSize, minY / SectorSize}, + {(minX + width + SectorSize - 1) / SectorSize, (minY + height + SectorSize - 1) / SectorSize} + }; +} + +template <typename ElementT, size_t SectorSize> +Vec2S SectorArray2D<ElementT, SectorSize>::sectorCorner(Sector const& id) const { + return Vec2S(id[0] * SectorSize, id[1] * SectorSize); +} + +template <typename ElementT, size_t SectorSize> +bool SectorArray2D<ElementT, SectorSize>::hasSector(Sector const& id) const { + starAssert(id[0] < m_sectors.size(0) && id[1] < m_sectors.size(1)); + return (bool)m_sectors(id[0], id[1]); +} + +template <typename ElementT, size_t SectorSize> +auto SectorArray2D<ElementT, SectorSize>::loadedSectors() const -> List<Sector> { + return m_loadedSectors.values(); +} + +template <typename ElementT, size_t SectorSize> +size_t SectorArray2D<ElementT, SectorSize>::loadedSectorCount() const { + return m_loadedSectors.size(); +} + +template <typename ElementT, size_t SectorSize> +bool SectorArray2D<ElementT, SectorSize>::sectorLoaded(Sector const& id) const { + return m_loadedSectors.contains(id); +} + +template <typename ElementT, size_t SectorSize> +auto SectorArray2D<ElementT, SectorSize>::sector(Sector const& id) -> Array * { + return m_sectors(id[0], id[1]).get(); +} + +template <typename ElementT, size_t SectorSize> +auto SectorArray2D<ElementT, SectorSize>::sector(Sector const& id) const -> Array const * { + return m_sectors(id[0], id[1]).get(); +} + +template <typename ElementT, size_t SectorSize> +void SectorArray2D<ElementT, SectorSize>::loadSector(Sector const& id, ArrayPtr array) { + auto& data = m_sectors(id[0], id[1]); + data = move(array); + if (data) + m_loadedSectors.add(id); + else + m_loadedSectors.remove(id); +} + +template <typename ElementT, size_t SectorSize> +typename SectorArray2D<ElementT, SectorSize>::ArrayPtr SectorArray2D<ElementT, SectorSize>::copySector( + Sector const& id) { + if (auto const& array = m_sectors(id)) + return make_unique<Array>(*array); + else + return {}; +} + +template <typename ElementT, size_t SectorSize> +typename SectorArray2D<ElementT, SectorSize>::ArrayPtr SectorArray2D<ElementT, SectorSize>::takeSector( + Sector const& id) { + ArrayPtr ret; + m_loadedSectors.remove(id); + std::swap(m_sectors(id[0], id[1]), ret); + return move(ret); +} + +template <typename ElementT, size_t SectorSize> +void SectorArray2D<ElementT, SectorSize>::discardSector(Sector const& id) { + m_loadedSectors.remove(id); + m_sectors(id[0], id[1]).reset(); +} + +template <typename ElementT, size_t SectorSize> +typename SectorArray2D<ElementT, SectorSize>::Element const* SectorArray2D<ElementT, SectorSize>::get( + size_t x, size_t y) const { + Array* array = m_sectors(x / SectorSize, y / SectorSize).get(); + if (array) { + return &(*array)(x % SectorSize, y % SectorSize); + } else { + return nullptr; + } +} + +template <typename ElementT, size_t SectorSize> +typename SectorArray2D<ElementT, SectorSize>::Element* SectorArray2D<ElementT, SectorSize>::get(size_t x, size_t y) { + Array* array = m_sectors(x / SectorSize, y / SectorSize).get(); + if (array) + return &(*array)(x % SectorSize, y % SectorSize); + else + return nullptr; +} + +template <typename ElementT, size_t SectorSize> +template <typename Function> +bool SectorArray2D<ElementT, SectorSize>::eval( + size_t minX, size_t minY, size_t width, size_t height, Function&& function, bool evalEmpty) const { + return const_cast<SectorArray2D*>(this)->evalPriv(minX, minY, width, height, forward<Function>(function), evalEmpty); +} + +template <typename ElementT, size_t SectorSize> +template <typename Function> +bool SectorArray2D<ElementT, SectorSize>::eval( + size_t minX, size_t minY, size_t width, size_t height, Function&& function, bool evalEmpty) { + return evalPriv(minX, minY, width, height, forward<Function>(function), evalEmpty); +} + +template <typename ElementT, size_t SectorSize> +template <typename Function> +bool SectorArray2D<ElementT, SectorSize>::evalColumns( + size_t minX, size_t minY, size_t width, size_t height, Function&& function, bool evalEmpty) const { + return const_cast<SectorArray2D*>(this)->evalColumnsPriv( + minX, minY, width, height, forward<Function>(function), evalEmpty); +} + +template <typename ElementT, size_t SectorSize> +template <typename Function> +bool SectorArray2D<ElementT, SectorSize>::evalColumns( + size_t minX, size_t minY, size_t width, size_t height, Function&& function, bool evalEmpty) { + return evalColumnsPriv(minX, minY, width, height, forward<Function>(function), evalEmpty); +} + +template <typename ElementT, size_t SectorSize> +template <typename Function> +bool SectorArray2D<ElementT, SectorSize>::evalPriv( + size_t minX, size_t minY, size_t width, size_t height, Function&& function, bool evalEmpty) { + return evalColumnsPriv(minX, + minY, + width, + height, + [&function](size_t x, size_t y, Element* column, size_t columnSize) { + for (size_t i = 0; i < columnSize; ++i) { + if (column) { + if (!function(x, y + i, column + i)) + return false; + } else { + if (!function(x, y + i, nullptr)) + return false; + } + } + return true; + }, + evalEmpty); +} + +template <typename ElementT, size_t SectorSize> +template <typename Function> +bool SectorArray2D<ElementT, SectorSize>::evalColumnsPriv( + size_t minX, size_t minY, size_t width, size_t height, Function&& function, bool evalEmpty) { + if (width == 0 || height == 0) + return true; + + size_t maxX = minX + width; + size_t maxY = minY + height; + size_t minXSector = minX / SectorSize; + size_t maxXSector = (maxX - 1) / SectorSize; + + size_t minYSector = minY / SectorSize; + size_t maxYSector = (maxY - 1) / SectorSize; + + for (size_t xSector = minXSector; xSector <= maxXSector; ++xSector) { + size_t minXi = 0; + if (xSector == minXSector) + minXi = minX % SectorSize; + + size_t maxXi = SectorSize - 1; + if (xSector == maxXSector) + maxXi = (maxX - 1) % SectorSize; + + for (size_t ySector = minYSector; ySector <= maxYSector; ++ySector) { + Array* array = m_sectors(xSector, ySector).get(); + + if (!array && !evalEmpty) + continue; + + size_t minYi = 0; + if (ySector == minYSector) + minYi = minY % SectorSize; + + size_t maxYi = SectorSize - 1; + if (ySector == maxYSector) + maxYi = (maxY - 1) % SectorSize; + + size_t y_ = ySector * SectorSize; + size_t x_ = xSector * SectorSize; + if (!array) { + for (size_t xi = minXi; xi <= maxXi; ++xi) { + if (!function(xi + x_, minYi + y_, nullptr, maxYi - minYi + 1)) + return false; + } + } else { + for (size_t xi = minXi; xi <= maxXi; ++xi) { + if (!function(xi + x_, minYi + y_, &array->elements[xi * SectorSize + minYi], maxYi - minYi + 1)) + return false; + } + } + } + } + + return true; +} + +} + +#endif diff --git a/source/core/StarSecureRandom.hpp b/source/core/StarSecureRandom.hpp new file mode 100644 index 0000000..e4110b0 --- /dev/null +++ b/source/core/StarSecureRandom.hpp @@ -0,0 +1,14 @@ +#ifndef STAR_SECURE_RANDOM_HPP +#define STAR_SECURE_RANDOM_HPP + +#include "StarByteArray.hpp" + +namespace Star { + +// Generate cryptographically secure random numbers for usage in password salts +// and such using OS facilities +ByteArray secureRandomBytes(size_t size); + +} + +#endif diff --git a/source/core/StarSecureRandom_unix.cpp b/source/core/StarSecureRandom_unix.cpp new file mode 100644 index 0000000..4ddb337 --- /dev/null +++ b/source/core/StarSecureRandom_unix.cpp @@ -0,0 +1,10 @@ +#include "StarSecureRandom.hpp" +#include "StarFile.hpp" + +namespace Star { + +ByteArray secureRandomBytes(size_t size) { + return File::open("/dev/urandom", IOMode::Read)->readBytes(size); +} + +} diff --git a/source/core/StarSecureRandom_windows.cpp b/source/core/StarSecureRandom_windows.cpp new file mode 100644 index 0000000..e118991 --- /dev/null +++ b/source/core/StarSecureRandom_windows.cpp @@ -0,0 +1,21 @@ +#include "StarSecureRandom.hpp" +#include <windows.h> +#include <wincrypt.h> + +namespace Star { + +ByteArray secureRandomBytes(size_t size) { + HCRYPTPROV context = 0; + auto res = ByteArray(size, '\0'); + + CryptAcquireContext(&context, 0, MS_DEF_PROV, PROV_RSA_FULL, CRYPT_VERIFYCONTEXT); + auto success = CryptGenRandom(context, size, (PBYTE)res.ptr()); + CryptReleaseContext(context, 0); + + if (!success) + throw StarException("Could not read random bytes from source."); + + return res; +} + +} diff --git a/source/core/StarSet.hpp b/source/core/StarSet.hpp new file mode 100644 index 0000000..8950e9c --- /dev/null +++ b/source/core/StarSet.hpp @@ -0,0 +1,322 @@ +#ifndef STAR_SET_HPP +#define STAR_SET_HPP + +#include <set> +#include <unordered_set> + +#include "StarFlatHashSet.hpp" +#include "StarList.hpp" + +namespace Star { + +STAR_EXCEPTION(SetException, StarException); + +template <typename BaseSet> +class SetMixin : public BaseSet { +public: + typedef BaseSet Base; + + typedef typename Base::iterator iterator; + typedef typename Base::const_iterator const_iterator; + + typedef typename Base::value_type value_type; + + using Base::Base; + + List<value_type> values() const; + + bool contains(value_type const& v) const; + + bool add(value_type const& v); + + // Like add, but always adds new value, potentially replacing another equal + // (comparing equal, may not be actually equal) value. Returns whether an + // existing value was replaced. + bool replace(value_type v); + + template <typename Container> + void addAll(Container const& s); + + bool remove(value_type const& v); + + template <typename Container> + void removeAll(Container const& s); + + value_type first(); + Maybe<value_type> maybeFirst(); + value_type takeFirst(); + Maybe<value_type> maybeTakeFirst(); + + value_type last(); + Maybe<value_type> maybeLast(); + value_type takeLast(); + Maybe<value_type> maybeTakeLast(); + + bool hasIntersection(SetMixin const& s) const; +}; + +template <typename BaseSet> +std::ostream& operator<<(std::ostream& os, SetMixin<BaseSet> const& set); + +template <typename Value, typename Compare = std::less<Value>, typename Allocator = std::allocator<Value>> +class Set : public SetMixin<std::set<Value, Compare, Allocator>> { +public: + typedef SetMixin<std::set<Value, Compare, Allocator>> Base; + + typedef typename Base::iterator iterator; + typedef typename Base::const_iterator const_iterator; + + typedef typename Base::value_type value_type; + + template <typename Container> + static Set from(Container const& c); + + using Base::Base; + + // Returns set of elements that are in this set and the given set. + Set intersection(Set const& s) const; + Set intersection(Set const& s, std::function<bool(Value const&, Value const&)> compare) const; + + // Returns elements in this set that are not in the given set + Set difference(Set const& s) const; + Set difference(Set const& s, std::function<bool(Value const&, Value const&)> compare) const; + + // Returns elements in either this set or the given set + Set combination(Set const& s) const; +}; + +template <typename BaseSet> +class HashSetMixin : public SetMixin<BaseSet> { +public: + typedef SetMixin<BaseSet> Base; + + typedef typename Base::iterator iterator; + typedef typename Base::const_iterator const_iterator; + + typedef typename Base::value_type value_type; + + template <typename Container> + static HashSetMixin from(Container const& c); + + using Base::Base; + + HashSetMixin intersection(HashSetMixin const& s) const; + HashSetMixin difference(HashSetMixin const& s) const; + HashSetMixin combination(HashSetMixin const& s) const; +}; + +template <typename Value, typename Hash = hash<Value>, typename Equals = std::equal_to<Value>, typename Allocator = std::allocator<Value>> +using HashSet = HashSetMixin<FlatHashSet<Value, Hash, Equals, Allocator>>; + +template <typename Value, typename Hash = hash<Value>, typename Equals = std::equal_to<Value>, typename Allocator = std::allocator<Value>> +using StableHashSet = HashSetMixin<std::unordered_set<Value, Hash, Equals, Allocator>>; + +template <typename BaseSet> +auto SetMixin<BaseSet>::values() const -> List<value_type> { + return List<value_type>(Base::begin(), Base::end()); +} + +template <typename BaseSet> +bool SetMixin<BaseSet>::contains(value_type const& v) const { + return Base::find(v) != Base::end(); +} + +template <typename BaseSet> +bool SetMixin<BaseSet>::add(value_type const& v) { + return Base::insert(v).second; +} + +template <typename BaseSet> +bool SetMixin<BaseSet>::replace(value_type v) { + bool replaced = remove(v); + Base::insert(move(v)); + return replaced; +} + +template <typename BaseSet> +template <typename Container> +void SetMixin<BaseSet>::addAll(Container const& s) { + return Base::insert(s.begin(), s.end()); +} + +template <typename BaseSet> +bool SetMixin<BaseSet>::remove(value_type const& v) { + return Base::erase(v) != 0; +} + +template <typename BaseSet> +template <typename Container> +void SetMixin<BaseSet>::removeAll(Container const& s) { + for (auto const& v : s) + remove(v); +} + +template <typename BaseSet> +auto SetMixin<BaseSet>::first() -> value_type { + if (Base::empty()) + throw SetException("first called on empty set"); + return *Base::begin(); +} + +template <typename BaseSet> +auto SetMixin<BaseSet>::maybeFirst() -> Maybe<value_type> { + if (Base::empty()) + return {}; + return *Base::begin(); +} + +template <typename BaseSet> +auto SetMixin<BaseSet>::takeFirst() -> value_type { + if (Base::empty()) + throw SetException("takeFirst called on empty set"); + auto i = Base::begin(); + value_type v = move(*i); + Base::erase(i); + return v; +} + +template <typename BaseSet> +auto SetMixin<BaseSet>::maybeTakeFirst() -> Maybe<value_type> { + if (Base::empty()) + return {}; + auto i = Base::begin(); + value_type v = move(*i); + Base::erase(i); + return move(v); +} + +template <typename BaseSet> +auto SetMixin<BaseSet>::last() -> value_type { + if (Base::empty()) + throw SetException("last called on empty set"); + return *prev(Base::end()); +} + +template <typename BaseSet> +auto SetMixin<BaseSet>::maybeLast() -> Maybe<value_type> { + if (Base::empty()) + return {}; + return *prev(Base::end()); +} + +template <typename BaseSet> +auto SetMixin<BaseSet>::takeLast() -> value_type { + if (Base::empty()) + throw SetException("takeLast called on empty set"); + auto i = prev(Base::end()); + value_type v = move(*i); + Base::erase(i); + return v; +} + +template <typename BaseSet> +auto SetMixin<BaseSet>::maybeTakeLast() -> Maybe<value_type> { + if (Base::empty()) + return {}; + auto i = prev(Base::end()); + value_type v = move(*i); + Base::erase(i); + return move(v); +} + +template <typename BaseSet> +bool SetMixin<BaseSet>::hasIntersection(SetMixin const& s) const { + for (auto const& v : s) { + if (contains(v)) { + return true; + } + } + return false; +} + +template <typename BaseSet> +std::ostream& operator<<(std::ostream& os, SetMixin<BaseSet> const& set) { + os << "("; + for (auto i = set.begin(); i != set.end(); ++i) { + if (i != set.begin()) + os << ", "; + os << *i; + } + os << ")"; + return os; +} + +template <typename Value, typename Compare, typename Allocator> +template <typename Container> +Set<Value, Compare, Allocator> Set<Value, Compare, Allocator>::from(Container const& c) { + return Set(c.begin(), c.end()); +} + +template <typename Value, typename Compare, typename Allocator> +Set<Value, Compare, Allocator> Set<Value, Compare, Allocator>::intersection(Set const& s) const { + Set res; + std::set_intersection(Base::begin(), Base::end(), s.begin(), s.end(), std::inserter(res, res.end())); + return res; +} + +template <typename Value, typename Compare, typename Allocator> +Set<Value, Compare, Allocator> Set<Value, Compare, Allocator>::intersection(Set const& s, std::function<bool(Value const&, Value const&)> compare) const { + Set res; + std::set_intersection(Base::begin(), Base::end(), s.begin(), s.end(), std::inserter(res, res.end()), compare); + return res; +} + +template <typename Value, typename Compare, typename Allocator> +Set<Value, Compare, Allocator> Set<Value, Compare, Allocator>::difference(Set const& s) const { + Set res; + std::set_difference(Base::begin(), Base::end(), s.begin(), s.end(), std::inserter(res, res.end())); + return res; +} + +template <typename Value, typename Compare, typename Allocator> +Set<Value, Compare, Allocator> Set<Value, Compare, Allocator>::difference(Set const& s, std::function<bool(Value const&, Value const&)> compare) const { + Set res; + std::set_difference(Base::begin(), Base::end(), s.begin(), s.end(), std::inserter(res, res.end()), compare); + return res; +} + +template <typename Value, typename Compare, typename Allocator> +Set<Value, Compare, Allocator> Set<Value, Compare, Allocator>::combination(Set const& s) const { + Set ret(*this); + ret.addAll(s); + return ret; +} + +template <typename BaseMap> +template <typename Container> +HashSetMixin<BaseMap> HashSetMixin<BaseMap>::from(Container const& c) { + return HashSetMixin(c.begin(), c.end()); +} + +template <typename BaseMap> +HashSetMixin<BaseMap> HashSetMixin<BaseMap>::intersection(HashSetMixin const& s) const { + // Can't use std::set_intersection, since not sorted, naive version is fine. + HashSetMixin ret; + for (auto const& e : s) { + if (contains(e)) + ret.add(e); + } + return ret; +} + +template <typename BaseMap> +HashSetMixin<BaseMap> HashSetMixin<BaseMap>::difference(HashSetMixin const& s) const { + // Can't use std::set_difference, since not sorted, naive version is fine. + HashSetMixin ret; + for (auto const& e : *this) { + if (!s.contains(e)) + ret.add(e); + } + return ret; +} + +template <typename BaseMap> +HashSetMixin<BaseMap> HashSetMixin<BaseMap>::combination(HashSetMixin const& s) const { + HashSetMixin ret(*this); + ret.addAll(s); + return ret; +} + +} + +#endif diff --git a/source/core/StarSha256.cpp b/source/core/StarSha256.cpp new file mode 100644 index 0000000..c63c32a --- /dev/null +++ b/source/core/StarSha256.cpp @@ -0,0 +1,260 @@ +#include "StarSha256.hpp" +#include "StarFormat.hpp" +#include "StarEncode.hpp" + +namespace Star { + +// An implementation of the SHA-256 hash function, this is endian neutral +// so should work just about anywhere. +// +// This code works much like the MD5 code provided by RSA. You sha_init() +// a "sha_state" then sha_process() the bytes you want and sha_done() to get +// the output. +// +// Revised Code: Complies to SHA-256 standard now. +// +// Tom St Denis + +// the K array +static const uint32_t K[64] = {0x428a2f98U, + 0x71374491U, + 0xb5c0fbcfU, + 0xe9b5dba5U, + 0x3956c25bU, + 0x59f111f1U, + 0x923f82a4U, + 0xab1c5ed5U, + 0xd807aa98U, + 0x12835b01U, + 0x243185beU, + 0x550c7dc3U, + 0x72be5d74U, + 0x80deb1feU, + 0x9bdc06a7U, + 0xc19bf174U, + 0xe49b69c1U, + 0xefbe4786U, + 0x0fc19dc6U, + 0x240ca1ccU, + 0x2de92c6fU, + 0x4a7484aaU, + 0x5cb0a9dcU, + 0x76f988daU, + 0x983e5152U, + 0xa831c66dU, + 0xb00327c8U, + 0xbf597fc7U, + 0xc6e00bf3U, + 0xd5a79147U, + 0x06ca6351U, + 0x14292967U, + 0x27b70a85U, + 0x2e1b2138U, + 0x4d2c6dfcU, + 0x53380d13U, + 0x650a7354U, + 0x766a0abbU, + 0x81c2c92eU, + 0x92722c85U, + 0xa2bfe8a1U, + 0xa81a664bU, + 0xc24b8b70U, + 0xc76c51a3U, + 0xd192e819U, + 0xd6990624U, + 0xf40e3585U, + 0x106aa070U, + 0x19a4c116U, + 0x1e376c08U, + 0x2748774cU, + 0x34b0bcb5U, + 0x391c0cb3U, + 0x4ed8aa4aU, + 0x5b9cca4fU, + 0x682e6ff3U, + 0x748f82eeU, + 0x78a5636fU, + 0x84c87814U, + 0x8cc70208U, + 0x90befffaU, + 0xa4506cebU, + 0xbef9a3f7U, + 0xc67178f2UL}; + +// Various logical functions +#define Ch(x, y, z) ((x & y) ^ (~x & z)) +#define Maj(x, y, z) ((x & y) ^ (x & z) ^ (y & z)) +#define S(x, n) (((x) >> ((n)&31)) | ((x) << (32 - ((n)&31)))) +#define R(x, n) ((x) >> (n)) +#define Sigma0(x) (S(x, 2) ^ S(x, 13) ^ S(x, 22)) +#define Sigma1(x) (S(x, 6) ^ S(x, 11) ^ S(x, 25)) +#define Gamma0(x) (S(x, 7) ^ S(x, 18) ^ R(x, 3)) +#define Gamma1(x) (S(x, 17) ^ S(x, 19) ^ R(x, 10)) + +// compress 512-bits +static void sha_compress(sha_state* md) { + uint32_t S[8], W[64], t0, t1; + int i; + + /* copy state into S */ + for (i = 0; i < 8; i++) + S[i] = md->state[i]; + + /* copy the state into 512-bits into W[0..15] */ + for (i = 0; i < 16; i++) + W[i] = (((uint32_t)md->buf[(4 * i) + 0]) << 24) | (((uint32_t)md->buf[(4 * i) + 1]) << 16) + | (((uint32_t)md->buf[(4 * i) + 2]) << 8) | (((uint32_t)md->buf[(4 * i) + 3])); + + /* fill W[16..63] */ + for (i = 16; i < 64; i++) + W[i] = Gamma1(W[i - 2]) + W[i - 7] + Gamma0(W[i - 15]) + W[i - 16]; + + /* Compress */ + for (i = 0; i < 64; i++) { + t0 = S[7] + Sigma1(S[4]) + Ch(S[4], S[5], S[6]) + K[i] + W[i]; + t1 = Sigma0(S[0]) + Maj(S[0], S[1], S[2]); + S[7] = S[6]; + S[6] = S[5]; + S[5] = S[4]; + S[4] = S[3] + t0; + S[3] = S[2]; + S[2] = S[1]; + S[1] = S[0]; + S[0] = t0 + t1; + } + + /* feedback */ + for (i = 0; i < 8; i++) + md->state[i] += S[i]; +} + +// init the SHA state +static void sha_init(sha_state* md) { + md->curlen = md->length = 0; + md->state[0] = 0x6A09E667U; + md->state[1] = 0xBB67AE85U; + md->state[2] = 0x3C6EF372U; + md->state[3] = 0xA54FF53AU; + md->state[4] = 0x510E527FU; + md->state[5] = 0x9B05688CU; + md->state[6] = 0x1F83D9ABU; + md->state[7] = 0x5BE0CD19U; +} + +static void sha_process(sha_state* md, uint8_t* buf, int len) { + while (len--) { + /* copy byte */ + md->buf[md->curlen++] = *buf++; + + /* is 64 bytes full? */ + if (md->curlen == 64) { + sha_compress(md); + md->length += 512; + md->curlen = 0; + } + } +} + +static void sha_done(sha_state* md, uint8_t* hash) { + int i; + + /* increase the length of the message */ + md->length += md->curlen * 8; + + /* append the '1' bit */ + md->buf[md->curlen++] = 0x80; + + /* if the length is currently above 56 bytes we append zeros then compress. + Then we can fall back to padding zeros and length encoding like normal. */ + + if (md->curlen > 56) { + for (; md->curlen < 64;) + md->buf[md->curlen++] = 0; + sha_compress(md); + md->curlen = 0; + } + + /* pad upto 56 bytes of zeroes */ + for (; md->curlen < 56;) + md->buf[md->curlen++] = 0; + + /* since all messages are under 2^32 bits we mark the top bits zero */ + for (i = 56; i < 60; i++) + md->buf[i] = 0; + + /* append length */ + for (i = 60; i < 64; i++) + md->buf[i] = (md->length >> ((63 - i) * 8)) & 255; + sha_compress(md); + + /* copy output */ + for (i = 0; i < 32; i++) + hash[i] = (md->state[i >> 2] >> (((3 - i) & 3) << 3)) & 255; +} + +Sha256Hasher::Sha256Hasher() { + m_finished = false; + sha_init(&m_state); +} + +void Sha256Hasher::push(char const* data, size_t length) { + if (m_finished) { + sha_init(&m_state); + m_finished = false; + } + + sha_process(&m_state, (uint8_t*)data, length); +} + +void Sha256Hasher::push(String const& data) { + push(data.utf8Ptr(), data.utf8Size()); +} + +void Sha256Hasher::push(ByteArray const& data) { + push(data.ptr(), data.size()); +} + +ByteArray Sha256Hasher::compute() { + ByteArray dest(32, 0); + sha_done(&m_state, (uint8_t*)dest.ptr()); + m_finished = true; + return dest; +} + +void Sha256Hasher::compute(char* hashDestination) { + sha_done(&m_state, (uint8_t*)hashDestination); + m_finished = true; +} + +void sha256(char const* source, size_t length, char* hashDestination) { + sha_state state; + sha_init(&state); + sha_process(&state, (uint8_t*)source, length); + sha_done(&state, (uint8_t*)hashDestination); +} + +ByteArray sha256(char const* source, size_t length) { + ByteArray dest(32, 0); + sha256(source, length, dest.ptr()); + return dest; +} + +void sha256(ByteArray const& in, ByteArray& out) { + out.resize(32, 0); + sha256(in.ptr(), in.size(), out.ptr()); +} + +void sha256(String const& in, ByteArray& out) { + out.resize(32, 0); + sha256(in.utf8Ptr(), in.utf8Size(), out.ptr()); +} + +ByteArray sha256(ByteArray const& in) { + return sha256(in.ptr(), in.size()); +} + +ByteArray sha256(String const& in) { + return sha256(in.utf8Ptr(), in.utf8Size()); +} + +} diff --git a/source/core/StarSha256.hpp b/source/core/StarSha256.hpp new file mode 100644 index 0000000..6a9abd4 --- /dev/null +++ b/source/core/StarSha256.hpp @@ -0,0 +1,44 @@ +#ifndef STAR_SHA_256_HPP +#define STAR_SHA_256_HPP + +#include "StarString.hpp" +#include "StarByteArray.hpp" + +namespace Star { + +typedef struct sha_state_struct { + uint32_t state[8], length, curlen; + uint8_t buf[64]; +} sha_state; + +class Sha256Hasher { +public: + Sha256Hasher(); + + void push(char const* data, size_t length); + void push(String const& data); + void push(ByteArray const& data); + + // Produces 32 bytes + void compute(char* hashDestination); + ByteArray compute(); + +private: + bool m_finished; + sha_state m_state; +}; + +// Sha256 must, obviously, have 32 bytes available in the destination. +void sha256(char const* source, size_t length, char* hashDestination); + +ByteArray sha256(char const* source, size_t length); + +void sha256(ByteArray const& in, ByteArray& out); +void sha256(String const& in, ByteArray& out); + +ByteArray sha256(ByteArray const& in); +ByteArray sha256(String const& in); + +} + +#endif diff --git a/source/core/StarShellParser.cpp b/source/core/StarShellParser.cpp new file mode 100644 index 0000000..31a927e --- /dev/null +++ b/source/core/StarShellParser.cpp @@ -0,0 +1,208 @@ +#include "StarShellParser.hpp" + +namespace Star { + +ShellParser::ShellParser() + : m_current(), m_end(), m_quotedType('\0') {} + +auto ShellParser::tokenize(String const& command) -> List<Token> { + List<Token> res; + + init(command); + + while (notDone()) { + res.append(Token{TokenType::Word, word()}); + } + + return res; +} + +StringList ShellParser::tokenizeToStringList(String const& command) { + StringList res; + for (auto token : tokenize(command)) { + if (token.type == TokenType::Word) { + res.append(move(token.token)); + } + } + + return res; +} + +void ShellParser::init(String const& string) { + m_begin = string.begin(); + m_current = m_begin; + m_end = string.end(); + m_quotedType = '\0'; +} + +String ShellParser::word() { + String res; + + while (notDone()) { + auto letter = *current(); + bool escapedLetter = false; + + if (letter == '\\') { + escapedLetter = true; + letter = parseBackslash(); + } + + if (!escapedLetter) { + if (isSpace(letter) && !inQuotedString()) { + next(); + if (res.size()) { + return res; + } + continue; + } + + if (isQuote(letter)) { + if (inQuotedString() && letter == m_quotedType) { + m_quotedType = '\0'; + next(); + continue; + } + + if (!inQuotedString()) { + m_quotedType = letter; + next(); + continue; + } + } + } + + res.append(letter); + next(); + } + + return res; +} + +bool ShellParser::isSpace(Char letter) const { + return String::isSpace(letter); +} + +bool ShellParser::isQuote(Char letter) const { + return letter == '\'' || letter == '"'; +} + +bool ShellParser::inQuotedString() const { + return m_quotedType != '\0'; +} + +auto ShellParser::current() const -> Maybe<Char> { + if (m_current == m_end) { + return {}; + } + + return *m_current; +} + +auto ShellParser::next() -> Maybe<Char> { + if (m_current != m_end) { + ++m_current; + } + + return current(); +} + +auto ShellParser::previous() -> Maybe<Char> { + if (m_current != m_begin) { + --m_current; + } + + return current(); +} + +auto ShellParser::parseBackslash() -> Char { + auto letter = next(); + + if (!letter) { + return '\\'; + } + + switch (*letter) { + case ' ': + return ' '; + case 'n': + return '\n'; + case 't': + return '\t'; + case 'r': + return '\r'; + case 'b': + return '\b'; + case 'v': + return '\v'; + case 'f': + return '\f'; + case 'a': + return '\a'; + case '\'': + return '\''; + case '"': + return '"'; + case '\\': + return '\\'; + case '0': + return '\0'; + case 'u': { + auto letter = parseUnicodeEscapeSequence(); + if (isUtf16LeadSurrogate(letter)) { + auto shouldBeSlash = next(); + if (shouldBeSlash && shouldBeSlash == '\\') { + auto shouldBeU = next(); + if (shouldBeU && shouldBeU == 'u') { + return parseUnicodeEscapeSequence(letter); + } else { + previous(); + } + } + previous(); + return STAR_UTF32_REPLACEMENT_CHAR; + } else { + return letter; + } + } + default: + return *letter; + } +} + +auto ShellParser::parseUnicodeEscapeSequence(Maybe<Char> previousCodepoint) -> Char { + String codepoint; + + auto letter = current(); + + while (!isSpace(*letter) && codepoint.size() < 4) { + auto letter = next(); + if (!letter) { + break; + } + + if (!isxdigit(*letter)) { + return STAR_UTF32_REPLACEMENT_CHAR; + } + + codepoint.append(*letter); + } + + if (!codepoint.size()) { + return 'u'; + } + + if (codepoint.size() != 4) // exactly 4 digits are required by \u + return STAR_UTF32_REPLACEMENT_CHAR; + + try { + return hexStringToUtf32(codepoint.utf8(), previousCodepoint); + } catch (UnicodeException const&) { + return STAR_UTF32_REPLACEMENT_CHAR; + } +} + +bool ShellParser::notDone() const { + return m_current != m_end; +} + +} diff --git a/source/core/StarShellParser.hpp b/source/core/StarShellParser.hpp new file mode 100644 index 0000000..8e89273 --- /dev/null +++ b/source/core/StarShellParser.hpp @@ -0,0 +1,66 @@ +#ifndef STAR_SHELL_PARSER_HPP +#define STAR_SHELL_PARSER_HPP + +#include "StarString.hpp" +#include "StarEncode.hpp" +#include "StarBytes.hpp" +#include "StarFormat.hpp" + +namespace Star { + +// Currently the specification of the "language" is incredibly simple The only +// thing we process are quoted strings and backslashes Backslashes function as +// a useful subset of C++ This means: Newline: \n Tab: \t Backslash: \\ Single +// Quote: \' Double Quote: \" Null: \0 Space: "\ " (without quotes ofc, not +// actually C++) Also \v \b \a \f \r Plus Unicode \uxxxx Not implemented octal +// and hexadecimal, because it's possible to construct invalid unicode code +// points using them + +STAR_EXCEPTION(ShellParsingException, StarException); + +class ShellParser { +public: + ShellParser(); + typedef String::Char Char; + + enum class TokenType { + Word, + // TODO: braces, brackets, actual shell stuff + + }; + + struct Token { + TokenType type; + String token; + }; + + List<Token> tokenize(String const& command); + StringList tokenizeToStringList(String const& command); + +private: + void init(String const& command); + + String word(); + Char parseBackslash(); + Char parseUnicodeEscapeSequence(Maybe<Char> previousCodepoint = {}); + + bool isSpace(Char letter) const; + bool isQuote(Char letter) const; + + bool inQuotedString() const; + bool notDone() const; + + Maybe<Char> current() const; + Maybe<Char> next(); + Maybe<Char> previous(); + + String::const_iterator m_begin; + String::const_iterator m_current; + String::const_iterator m_end; + + Char m_quotedType; +}; + +} + +#endif diff --git a/source/core/StarSignalHandler.hpp b/source/core/StarSignalHandler.hpp new file mode 100644 index 0000000..0b70d71 --- /dev/null +++ b/source/core/StarSignalHandler.hpp @@ -0,0 +1,37 @@ +#ifndef STAR_SIGNAL_HANDLER_HPP +#define STAR_SIGNAL_HANDLER_HPP + +#include "StarException.hpp" + +namespace Star { + +STAR_STRUCT(SignalHandlerImpl); + +// Singleton signal handler that registers handlers for segfault, fpe, +// illegal instructions etc as well as non-fatal interrupts. +class SignalHandler { +public: + SignalHandler(); + ~SignalHandler(); + + // If enabled, will catch segfault, fpe, and illegal instructions and output + // error information before dying. + void setHandleFatal(bool handleFatal); + bool handlingFatal() const; + + // If enabled, non-fatal interrupt signal will be caught and will not kill + // the process and will instead set the interrupted flag. + void setHandleInterrupt(bool handleInterrupt); + bool handlingInterrupt() const; + + bool interruptCaught() const; + +private: + friend SignalHandlerImpl; + + static SignalHandlerImplUPtr s_singleton; +}; + +} + +#endif diff --git a/source/core/StarSignalHandler_unix.cpp b/source/core/StarSignalHandler_unix.cpp new file mode 100644 index 0000000..df8114a --- /dev/null +++ b/source/core/StarSignalHandler_unix.cpp @@ -0,0 +1,92 @@ +#include "StarSignalHandler.hpp" + +#include <signal.h> + +namespace Star { + +struct SignalHandlerImpl { + bool handlingFatal; + bool handlingInterrupt; + bool interrupted; + + SignalHandlerImpl() : handlingFatal(false), handlingInterrupt(false), interrupted(false) {} + + ~SignalHandlerImpl() { + setHandleFatal(false); + setHandleInterrupt(false); + } + + void setHandleFatal(bool b) { + handlingFatal = b; + if (handlingFatal) { + signal(SIGSEGV, handleFatal); + signal(SIGILL, handleFatal); + signal(SIGFPE, handleFatal); + signal(SIGBUS, handleFatal); + } else { + signal(SIGSEGV, SIG_DFL); + signal(SIGILL, SIG_DFL); + signal(SIGFPE, SIG_DFL); + signal(SIGBUS, SIG_DFL); + } + } + + void setHandleInterrupt(bool b) { + handlingInterrupt = b; + if (handlingInterrupt) + signal(SIGINT, handleInterrupt); + else + signal(SIGINT, SIG_DFL); + } + + static void handleFatal(int signum) { + if (signum == SIGSEGV) + fatalError("Segfault Encountered!", true); + else if (signum == SIGILL) + fatalError("Illegal Instruction Encountered!", true); + else if (signum == SIGFPE) + fatalError("Floating Point Exception Encountered!", true); + else if (signum == SIGBUS) + fatalError("Bus Error Encountered!", true); + } + + static void handleInterrupt(int) { + if (SignalHandler::s_singleton) + SignalHandler::s_singleton->interrupted = true; + } +}; + +SignalHandlerImplUPtr SignalHandler::s_singleton; + +SignalHandler::SignalHandler() { + if (s_singleton) + throw StarException("Singleton SignalHandler has been constructed twice!"); + + s_singleton = make_unique<SignalHandlerImpl>(); +} + +SignalHandler::~SignalHandler() { + s_singleton.reset(); +} + +void SignalHandler::setHandleFatal(bool handleFatal) { + s_singleton->setHandleFatal(handleFatal); +} + +bool SignalHandler::handlingFatal() const { + return s_singleton->handlingFatal; +} + +void SignalHandler::setHandleInterrupt(bool handleInterrupt) { + s_singleton->setHandleInterrupt(handleInterrupt); +} + +bool SignalHandler::handlingInterrupt() const { + return s_singleton->handlingInterrupt; +} + +bool SignalHandler::interruptCaught() const { + return s_singleton->interrupted; +} + +} diff --git a/source/core/StarSignalHandler_windows.cpp b/source/core/StarSignalHandler_windows.cpp new file mode 100644 index 0000000..0445aba --- /dev/null +++ b/source/core/StarSignalHandler_windows.cpp @@ -0,0 +1,190 @@ +#include "StarSignalHandler.hpp" +#include "StarFormat.hpp" +#include "StarString.hpp" +#include "StarLogging.hpp" + +#include <windows.h> + +namespace Star { + +String g_sehMessage; + +struct SignalHandlerImpl { + bool handlingFatal; + bool handlingInterrupt; + bool interrupted; + + PVOID handler; + + SignalHandlerImpl() : handlingFatal(false), handlingInterrupt(false), interrupted(false) {} + + ~SignalHandlerImpl() { + setHandleFatal(false); + setHandleInterrupt(false); + } + + void setHandleFatal(bool b) { + handlingFatal = b; + + if (handler) { + RemoveVectoredExceptionHandler(handler); + handler = nullptr; + } + + if (handlingFatal) + handler = AddVectoredExceptionHandler(1, vectoredExceptionHandler); + } + + void setHandleInterrupt(bool b) { + handlingInterrupt = b; + + SetConsoleCtrlHandler(nullptr, false); + + if (handlingInterrupt) + SetConsoleCtrlHandler((PHANDLER_ROUTINE)consoleCtrlHandler, true); + } + + static void sehTrampoline() { + fatalError(g_sehMessage.utf8Ptr(), true); + } + + static void handleFatalError(String const& msg, PEXCEPTION_POINTERS ExceptionInfo) { + if (ExceptionInfo->ExceptionRecord->ExceptionCode == EXCEPTION_ACCESS_VIOLATION) { + String mode; + DWORD modeFlag = ExceptionInfo->ExceptionRecord->ExceptionInformation[0]; + if (modeFlag == 0) + mode = "Read"; + else if (modeFlag == 1) + mode = "Write"; + else if (modeFlag == 8) + mode = "Execute"; + else + mode = strf("Mode(%s)", modeFlag); + g_sehMessage = strf("Access violation detected at %s (%s of address %s)", + ExceptionInfo->ExceptionRecord->ExceptionAddress, + mode, + (PVOID)ExceptionInfo->ExceptionRecord->ExceptionInformation[1]); + } else { + g_sehMessage = msg; + g_sehMessage = strf("%s (%p @ %s)", + g_sehMessage, + ExceptionInfo->ExceptionRecord->ExceptionCode, + ExceptionInfo->ExceptionRecord->ExceptionAddress); + for (DWORD i = 0; i < ExceptionInfo->ExceptionRecord->NumberParameters; i++) + g_sehMessage = strf("%s [%s]", g_sehMessage, (PVOID)ExceptionInfo->ExceptionRecord->ExceptionInformation[i]); + } + +// setup a hijack into our own trampoline as if the failure actually was a +// function call +#ifdef STAR_ARCHITECTURE_X86_64 + DWORD64 rsp = ExceptionInfo->ContextRecord->Rsp - 8; + DWORD64 rip = ExceptionInfo->ContextRecord->Rip; // an offset avoid the issue of gdb thinking + // the error is one statement too early, but + // the offset is instruction dependent, and we + // don't know its size + 1; + *((DWORD64*)rsp) = rip; + ExceptionInfo->ContextRecord->Rsp = rsp; + ExceptionInfo->ContextRecord->Rip = (DWORD64)&sehTrampoline; +#else + DWORD esp = ExceptionInfo->ContextRecord->Esp - 4; + DWORD eip = ExceptionInfo->ContextRecord->Eip; // an offset avoid the issue of gdb thinking the + // error is one statement too early, but the + // offset is instruction dependent, and we don't + // know its size + 1; + *((DWORD*)esp) = eip; + ExceptionInfo->ContextRecord->Esp = esp; + ExceptionInfo->ContextRecord->Eip = (DWORD)&sehTrampoline; +#endif + } + + static LONG CALLBACK vectoredExceptionHandler(PEXCEPTION_POINTERS ExceptionInfo) { + if (ExceptionInfo->ExceptionRecord->ExceptionCode == EXCEPTION_STACK_OVERFLOW) { + fatalError("Stack overflow encountered", false); + } + if (ExceptionInfo->ExceptionRecord->ExceptionCode == EXCEPTION_ACCESS_VIOLATION) { + handleFatalError("Access violation detected", ExceptionInfo); + return EXCEPTION_CONTINUE_EXECUTION; + } + if ((ExceptionInfo->ExceptionRecord->ExceptionCode == EXCEPTION_ILLEGAL_INSTRUCTION) + || (ExceptionInfo->ExceptionRecord->ExceptionCode == EXCEPTION_PRIV_INSTRUCTION)) { + handleFatalError("Illegal instruction encountered", ExceptionInfo); + return EXCEPTION_CONTINUE_EXECUTION; + } + + if ((ExceptionInfo->ExceptionRecord->ExceptionCode == EXCEPTION_FLT_DENORMAL_OPERAND) + || (ExceptionInfo->ExceptionRecord->ExceptionCode == EXCEPTION_FLT_DIVIDE_BY_ZERO) + || (ExceptionInfo->ExceptionRecord->ExceptionCode == EXCEPTION_FLT_INEXACT_RESULT) + || (ExceptionInfo->ExceptionRecord->ExceptionCode == EXCEPTION_FLT_INVALID_OPERATION) + || (ExceptionInfo->ExceptionRecord->ExceptionCode == EXCEPTION_FLT_OVERFLOW) + || (ExceptionInfo->ExceptionRecord->ExceptionCode == EXCEPTION_FLT_STACK_CHECK) + || (ExceptionInfo->ExceptionRecord->ExceptionCode == EXCEPTION_FLT_UNDERFLOW) + + ) { + handleFatalError("Floating point exception", ExceptionInfo); + return EXCEPTION_CONTINUE_EXECUTION; + } + + if (ExceptionInfo->ExceptionRecord->ExceptionCode == EXCEPTION_INT_DIVIDE_BY_ZERO) { + handleFatalError("Division by zero", ExceptionInfo); + return EXCEPTION_CONTINUE_EXECUTION; + } + + if (ExceptionInfo->ExceptionRecord->ExceptionCode == EXCEPTION_INT_OVERFLOW) { + handleFatalError("Integer overflow", ExceptionInfo); + return EXCEPTION_CONTINUE_EXECUTION; + } + + if ((ExceptionInfo->ExceptionRecord->ExceptionCode == EXCEPTION_DATATYPE_MISALIGNMENT) + || (ExceptionInfo->ExceptionRecord->ExceptionCode == EXCEPTION_ARRAY_BOUNDS_EXCEEDED) + || (ExceptionInfo->ExceptionRecord->ExceptionCode == EXCEPTION_IN_PAGE_ERROR) + || (ExceptionInfo->ExceptionRecord->ExceptionCode == EXCEPTION_NONCONTINUABLE_EXCEPTION) + || (ExceptionInfo->ExceptionRecord->ExceptionCode == EXCEPTION_INVALID_DISPOSITION) + || (ExceptionInfo->ExceptionRecord->ExceptionCode == EXCEPTION_INVALID_HANDLE)) { + handleFatalError("Error occured", ExceptionInfo); + return EXCEPTION_CONTINUE_EXECUTION; + } + + return EXCEPTION_CONTINUE_SEARCH; + } + + static BOOL WINAPI consoleCtrlHandler(DWORD) { + if (SignalHandler::s_singleton) + SignalHandler::s_singleton->interrupted = true; + return true; + } +}; + +SignalHandlerImplUPtr SignalHandler::s_singleton; + +SignalHandler::SignalHandler() { + if (s_singleton) + throw StarException("Singleton SignalHandler has been constructed twice!"); + + s_singleton = make_unique<SignalHandlerImpl>(); +} + +SignalHandler::~SignalHandler() { + s_singleton.reset(); +} + +void SignalHandler::setHandleFatal(bool handleFatal) { + s_singleton->setHandleFatal(handleFatal); +} + +bool SignalHandler::handlingFatal() const { + return s_singleton->handlingFatal; +} + +void SignalHandler::setHandleInterrupt(bool handleInterrupt) { + s_singleton->setHandleInterrupt(handleInterrupt); +} + +bool SignalHandler::handlingInterrupt() const { + return s_singleton->handlingInterrupt; +} + +bool SignalHandler::interruptCaught() const { + return s_singleton->interrupted; +} + +} diff --git a/source/core/StarSmallVector.hpp b/source/core/StarSmallVector.hpp new file mode 100644 index 0000000..38d32a3 --- /dev/null +++ b/source/core/StarSmallVector.hpp @@ -0,0 +1,447 @@ +#ifndef STAR_SMALL_VECTOR_HPP +#define STAR_SMALL_VECTOR_HPP + +#include "StarAlgorithm.hpp" + +namespace Star { + +// A vector that is stack allocated up to a maximum size, becoming heap +// allocated when it grows beyond that size. Always takes up stack space of +// MaxStackSize * sizeof(Element). +template <typename Element, size_t MaxStackSize> +class SmallVector { +public: + typedef Element* iterator; + typedef Element const* const_iterator; + + typedef std::reverse_iterator<iterator> reverse_iterator; + typedef std::reverse_iterator<const_iterator> const_reverse_iterator; + + typedef Element value_type; + + typedef Element& reference; + typedef Element const& const_reference; + + SmallVector(); + SmallVector(SmallVector const& other); + SmallVector(SmallVector&& other); + template <typename OtherElement, size_t OtherMaxStackSize> + SmallVector(SmallVector<OtherElement, OtherMaxStackSize> const& other); + template <class Iterator> + SmallVector(Iterator first, Iterator last); + SmallVector(size_t size, Element const& value = Element()); + SmallVector(initializer_list<Element> list); + ~SmallVector(); + + SmallVector& operator=(SmallVector const& other); + SmallVector& operator=(SmallVector&& other); + SmallVector& operator=(std::initializer_list<Element> list); + + size_t size() const; + bool empty() const; + void resize(size_t size, Element const& e = Element()); + void reserve(size_t capacity); + + reference at(size_t i); + const_reference at(size_t i) const; + + reference operator[](size_t i); + const_reference operator[](size_t i) const; + + const_iterator begin() const; + const_iterator end() const; + + iterator begin(); + iterator end(); + + const_reverse_iterator rbegin() const; + const_reverse_iterator rend() const; + + reverse_iterator rbegin(); + reverse_iterator rend(); + + // Pointer to internal data, always valid even if empty. + Element const* ptr() const; + Element* ptr(); + + void push_back(Element e); + void pop_back(); + + iterator insert(iterator pos, Element e); + template <typename Iterator> + iterator insert(iterator pos, Iterator begin, Iterator end); + iterator insert(iterator pos, initializer_list<Element> list); + + template <typename... Args> + void emplace(iterator pos, Args&&... args); + + template <typename... Args> + void emplace_back(Args&&... args); + + void clear(); + + iterator erase(iterator pos); + iterator erase(iterator begin, iterator end); + + bool operator==(SmallVector const& other) const; + bool operator!=(SmallVector const& other) const; + bool operator<(SmallVector const& other) const; + +private: + typename std::aligned_storage<MaxStackSize * sizeof(Element), alignof(Element)>::type m_stackElements; + + bool isHeapAllocated() const; + + Element* m_begin; + Element* m_end; + Element* m_capacity; +}; + +template <typename Element, size_t MaxStackSize> +SmallVector<Element, MaxStackSize>::SmallVector() { + m_begin = (Element*)&m_stackElements; + m_end = m_begin; + m_capacity = m_begin + MaxStackSize; +} + +template <typename Element, size_t MaxStackSize> +SmallVector<Element, MaxStackSize>::~SmallVector() { + clear(); + if (isHeapAllocated()) { + free(m_begin, (m_capacity - m_begin) * sizeof(Element)); + } +} + +template <typename Element, size_t MaxStackSize> +SmallVector<Element, MaxStackSize>::SmallVector(SmallVector const& other) + : SmallVector() { + insert(begin(), other.begin(), other.end()); +} + +template <typename Element, size_t MaxStackSize> +SmallVector<Element, MaxStackSize>::SmallVector(SmallVector&& other) + : SmallVector() { + for (auto& e : other) + emplace_back(move(e)); +} + +template <typename Element, size_t MaxStackSize> +template <typename OtherElement, size_t OtherMaxStackSize> +SmallVector<Element, MaxStackSize>::SmallVector(SmallVector<OtherElement, OtherMaxStackSize> const& other) + : SmallVector() { + for (auto const& e : other) + emplace_back(e); +} + +template <typename Element, size_t MaxStackSize> +template <class Iterator> +SmallVector<Element, MaxStackSize>::SmallVector(Iterator first, Iterator last) + : SmallVector() { + insert(begin(), first, last); +} + +template <typename Element, size_t MaxStackSize> +SmallVector<Element, MaxStackSize>::SmallVector(size_t size, Element const& value) + : SmallVector() { + resize(size, value); +} + +template <typename Element, size_t MaxStackSize> +SmallVector<Element, MaxStackSize>::SmallVector(initializer_list<Element> list) + : SmallVector() { + for (auto const& e : list) + emplace_back(e); +} + +template <typename Element, size_t MaxStackSize> +auto SmallVector<Element, MaxStackSize>::operator=(SmallVector const& other) -> SmallVector& { + if (this == &other) + return *this; + + resize(other.size()); + for (size_t i = 0; i < size(); ++i) + operator[](i) = other[i]; + + return *this; +} + +template <typename Element, size_t MaxStackSize> +auto SmallVector<Element, MaxStackSize>::operator=(SmallVector&& other) -> SmallVector& { + resize(other.size()); + for (size_t i = 0; i < size(); ++i) + operator[](i) = move(other[i]); + + return *this; +} + +template <typename Element, size_t MaxStackSize> +auto SmallVector<Element, MaxStackSize>::operator=(std::initializer_list<Element> list) -> SmallVector& { + resize(list.size()); + for (size_t i = 0; i < size(); ++i) + operator[](i) = move(list[i]); + return *this; +} + +template <typename Element, size_t MaxStackSize> +size_t SmallVector<Element, MaxStackSize>::size() const { + return m_end - m_begin; +} + +template <typename Element, size_t MaxStackSize> +bool SmallVector<Element, MaxStackSize>::empty() const { + return m_begin == m_end; +} + +template <typename Element, size_t MaxStackSize> +void SmallVector<Element, MaxStackSize>::resize(size_t size, Element const& e) { + reserve(size); + + for (size_t i = this->size(); i > size; --i) + pop_back(); + for (size_t i = this->size(); i < size; ++i) + emplace_back(e); +} + +template <typename Element, size_t MaxStackSize> +void SmallVector<Element, MaxStackSize>::reserve(size_t newCapacity) { + size_t oldCapacity = m_capacity - m_begin; + if (newCapacity > oldCapacity) { + newCapacity = max(oldCapacity * 2, newCapacity); + auto newMem = (Element*)Star::malloc(newCapacity * sizeof(Element)); + if (!newMem) + throw MemoryException::format("Could not set new SmallVector capacity %s\n", newCapacity); + + size_t size = m_end - m_begin; + auto oldMem = m_begin; + auto oldHeapAllocated = isHeapAllocated(); + + // We assume that move constructors can never throw. + for (size_t i = 0; i < size; ++i) { + new (&newMem[i]) Element(move(oldMem[i])); + } + + m_begin = newMem; + m_end = m_begin + size; + m_capacity = m_begin + newCapacity; + + auto freeOldMem = finally([=]() { + if (oldHeapAllocated) + Star::free(oldMem, oldCapacity * sizeof(Element)); + }); + + for (size_t i = 0; i < size; ++i) { + oldMem[i].~Element(); + } + } +} + +template <typename Element, size_t MaxStackSize> +auto SmallVector<Element, MaxStackSize>::at(size_t i) -> reference { + if (i >= size()) + throw OutOfRangeException::format("out of range in SmallVector::at(%s)", i); + return m_begin[i]; +} + +template <typename Element, size_t MaxStackSize> +auto SmallVector<Element, MaxStackSize>::at(size_t i) const -> const_reference { + if (i >= size()) + throw OutOfRangeException::format("out of range in SmallVector::at(%s)", i); + return m_begin[i]; +} + +template <typename Element, size_t MaxStackSize> +auto SmallVector<Element, MaxStackSize>::operator[](size_t i) -> reference { + starAssert(i < size()); + return m_begin[i]; +} + +template <typename Element, size_t MaxStackSize> +auto SmallVector<Element, MaxStackSize>::operator[](size_t i) const -> const_reference { + starAssert(i < size()); + return m_begin[i]; +} + +template <typename Element, size_t MaxStackSize> +auto SmallVector<Element, MaxStackSize>::begin() const -> const_iterator { + return m_begin; +} + +template <typename Element, size_t MaxStackSize> +auto SmallVector<Element, MaxStackSize>::end() const -> const_iterator { + return m_end; +} + +template <typename Element, size_t MaxStackSize> +auto SmallVector<Element, MaxStackSize>::begin() -> iterator { + return m_begin; +} + +template <typename Element, size_t MaxStackSize> +auto SmallVector<Element, MaxStackSize>::end() -> iterator { + return m_end; +} + +template <typename Element, size_t MaxStackSize> +auto SmallVector<Element, MaxStackSize>::rbegin() const -> const_reverse_iterator { + return const_reverse_iterator(end()); +} + +template <typename Element, size_t MaxStackSize> +auto SmallVector<Element, MaxStackSize>::rend() const -> const_reverse_iterator { + return const_reverse_iterator(begin()); +} + +template <typename Element, size_t MaxStackSize> +auto SmallVector<Element, MaxStackSize>::rbegin() -> reverse_iterator { + return reverse_iterator(end()); +} + +template <typename Element, size_t MaxStackSize> +auto SmallVector<Element, MaxStackSize>::rend() -> reverse_iterator { + return reverse_iterator(begin()); +} + +template <typename Element, size_t MaxStackSize> +Element const* SmallVector<Element, MaxStackSize>::ptr() const { + return m_begin; +} + +template <typename Element, size_t MaxStackSize> +Element* SmallVector<Element, MaxStackSize>::ptr() { + return m_begin; +} + +template <typename Element, size_t MaxStackSize> +void SmallVector<Element, MaxStackSize>::push_back(Element e) { + emplace_back(move(e)); +} + +template <typename Element, size_t MaxStackSize> +void SmallVector<Element, MaxStackSize>::pop_back() { + if (m_begin == m_end) + throw OutOfRangeException("SmallVector::pop_back called on empty SmallVector"); + --m_end; + m_end->~Element(); +} + +template <typename Element, size_t MaxStackSize> +auto SmallVector<Element, MaxStackSize>::insert(iterator pos, Element e) -> iterator { + emplace(pos, move(e)); + return pos; +} + +template <typename Element, size_t MaxStackSize> +template <typename Iterator> +auto SmallVector<Element, MaxStackSize>::insert(iterator pos, Iterator begin, Iterator end) -> iterator { + size_t toAdd = std::distance(begin, end); + size_t startIndex = pos - m_begin; + size_t endIndex = startIndex + toAdd; + size_t toShift = size() - startIndex; + + resize(size() + toAdd); + + for (size_t i = toShift; i != 0; --i) + operator[](endIndex + i - 1) = move(operator[](startIndex + i - 1)); + + for (size_t i = 0; i != toAdd; ++i) + operator[](startIndex + i) = *begin++; + + return pos; +} + +template <typename Element, size_t MaxStackSize> +auto SmallVector<Element, MaxStackSize>::insert(iterator pos, initializer_list<Element> list) -> iterator { + return insert(pos, list.begin(), list.end()); +} + +template <typename Element, size_t MaxStackSize> +template <typename... Args> +void SmallVector<Element, MaxStackSize>::emplace(iterator pos, Args&&... args) { + size_t index = pos - m_begin; + emplace_back(Element()); + for (size_t i = size() - 1; i != index; --i) + operator[](i) = move(operator[](i - 1)); + operator[](index) = Element(forward<Args>(args)...); +} + +template <typename Element, size_t MaxStackSize> +template <typename... Args> +void SmallVector<Element, MaxStackSize>::emplace_back(Args&&... args) { + if (m_end == m_capacity) + reserve(size() + 1); + new (m_end) Element(forward<Args>(args)...); + ++m_end; +} + +template <typename Element, size_t MaxStackSize> +void SmallVector<Element, MaxStackSize>::clear() { + while (m_begin != m_end) + pop_back(); +} + +template <typename Element, size_t MaxStackSize> +auto SmallVector<Element, MaxStackSize>::erase(iterator pos) -> iterator { + size_t index = pos - ptr(); + for (size_t i = index; i < size() - 1; ++i) + operator[](i) = move(operator[](i + 1)); + pop_back(); + return pos; +} + +template <typename Element, size_t MaxStackSize> +auto SmallVector<Element, MaxStackSize>::erase(iterator begin, iterator end) -> iterator { + size_t startIndex = begin - ptr(); + size_t endIndex = end - ptr(); + size_t toRemove = endIndex - startIndex; + for (size_t i = endIndex; i < size(); ++i) + operator[](startIndex + (i - endIndex)) = move(operator[](i)); + resize(size() - toRemove); + return begin; +} + +template <typename Element, size_t MaxStackSize> +bool SmallVector<Element, MaxStackSize>::operator==(SmallVector const& other) const { + if (this == &other) + return true; + + if (size() != other.size()) + return false; + + for (size_t i = 0; i < size(); ++i) { + if (operator[](i) != other[i]) + return false; + } + return true; +} + +template <typename Element, size_t MaxStackSize> +bool SmallVector<Element, MaxStackSize>::operator!=(SmallVector const& other) const { + return !operator==(other); +} + +template <typename Element, size_t MaxStackSize> +bool SmallVector<Element, MaxStackSize>::operator<(SmallVector const& other) const { + for (size_t i = 0; i < size(); ++i) { + if (i >= other.size()) + return false; + + Element const& a = operator[](i); + Element const& b = other[i]; + + if (a < b) + return true; + else if (b < a) + return false; + } + + return size() < other.size(); +} + +template <typename Element, size_t MaxStackSize> +bool SmallVector<Element, MaxStackSize>::isHeapAllocated() const { + return m_begin != (Element*)&m_stackElements; +} + +} + +#endif diff --git a/source/core/StarSocket.cpp b/source/core/StarSocket.cpp new file mode 100644 index 0000000..95ffe7c --- /dev/null +++ b/source/core/StarSocket.cpp @@ -0,0 +1,272 @@ +#include "StarSocket.hpp" +#include "StarLogging.hpp" +#include "StarNetImpl.hpp" + +namespace Star { + +Maybe<SocketPollResult> Socket::poll(SocketPollQuery const& query, unsigned timeout) { + if (query.empty()) + return {}; + + // Prevent close from being called on any socket during this call. + LinkedList<ReadLocker> readLockers; + for (auto const& p : query) + readLockers.emplaceAppend(p.first->m_mutex); + + // If any sockets are already closed, then this is an "event" according to + // this api but we cannot call poll on a closed socket, so just poll the rest + // of the sockets with no wait. + SocketPollResult result; + for (auto const& p : query) { + if (!p.first->isOpen()) { + result[p.first].exception = true; + timeout = 0; + } + } + +#ifdef STAR_SYSTEM_FAMILY_WINDOWS + fd_set readfs; + fd_set writefs; + fd_set exceptfs; + + FD_ZERO(&readfs); + FD_ZERO(&writefs); + FD_ZERO(&exceptfs); + + int ret; + for (auto const& p : query) { + if (p.first->isOpen()) { + if (p.second.readable) + FD_SET(p.first->m_impl->socketDesc, &readfs); + if (p.second.writable) + FD_SET(p.first->m_impl->socketDesc, &writefs); + FD_SET(p.first->m_impl->socketDesc, &exceptfs); + } + } + timeval time; + time.tv_usec = (timeout % 1000) * 1000; + time.tv_sec = timeout - timeout % 1000; + ret = ::select(0, &readfs, &writefs, &exceptfs, &time); + + if (ret < 0) + throw NetworkException::format("Error during call to select, '%s'", netErrorString()); + + if (ret == 0) + return {}; + + for (auto const& p : query) { + if (p.first->isOpen()) { + auto& r = result[p.first]; + r.readable = FD_ISSET(p.first->m_impl->socketDesc, &readfs); + r.writable = FD_ISSET(p.first->m_impl->socketDesc, &writefs); + r.exception = FD_ISSET(p.first->m_impl->socketDesc, &exceptfs); + if (r.exception) + p.first->doShutdown(); + } + } + +#else + unique_ptr<pollfd[]> pollfds(new pollfd[query.size()]); + int ret = 0; + for (auto p : enumerateIterator(query)) { + if (p.first.first->isOpen()) { + auto& pfd = pollfds[p.second]; + pfd.fd = p.first.first->m_impl->socketDesc; + pfd.events = 0; + if (p.first.second.readable) + pfd.events |= POLLIN; + if (p.first.second.writable) + pfd.events |= POLLOUT; + } + } + ret = ::poll(pollfds.get(), query.size(), timeout); + + if (ret < 0) + throw NetworkException::format("Error during call to poll, '%s'", netErrorString()); + + if (ret == 0) + return {}; + + for (auto p : enumerateIterator(query)) { + if (p.first.first->isOpen()) { + auto& pfd = pollfds[p.second]; + SocketPollResultEntry pr; + pr.readable = pfd.revents & POLLIN; + pr.writable = pfd.revents & POLLOUT; + pr.exception = pfd.revents & POLLHUP || pfd.revents & POLLNVAL || pfd.revents & POLLERR; + if (pfd.revents & POLLHUP) + p.first.first->doShutdown(); + result.add(p.first.first, move(pr)); + } + } +#endif + + readLockers.clear(); + + return result; +} + +Socket::~Socket() { + close(); +} + +void Socket::bind(HostAddressWithPort const& addressWithPort) { + WriteLocker locker(m_mutex); + checkOpen("Socket::bind"); + + struct sockaddr_storage sockAddr; + socklen_t sockAddrLen; + + if (addressWithPort.address().mode() != m_networkMode) + throw NetworkException("Bind address does not match socket mode"); + + // Ensure quick restarts don't prevent us binding + int set = 1; + m_impl->setSockOpt(SOL_SOCKET, SO_REUSEADDR, (void*)&set, sizeof(int)); + + m_localAddress = addressWithPort; + setNativeFromAddress(m_localAddress, &sockAddr, &sockAddrLen); + if (::bind(m_impl->socketDesc, (struct sockaddr*)&sockAddr, sockAddrLen) < 0) + throw NetworkException(strf("Cannot bind socket to %s: %s", m_localAddress, netErrorString())); + + m_socketMode = SocketMode::Bound; + + Logger::debug("bind %s (%d)", addressWithPort, m_impl->socketDesc); +} + +void Socket::listen(int backlog) { + WriteLocker locker(m_mutex); + + if (::listen(m_impl->socketDesc, backlog) != 0) + throw NetworkException(strf("Could not listen on socket: '%s'", netErrorString())); +} + +void Socket::setTimeout(unsigned timeout) { + ReadLocker locker(m_mutex); + checkOpen("Socket::setTimeout"); + + void* val; + socklen_t size; +#ifdef STAR_SYSTEM_FAMILY_WINDOWS + val = &timeout; + size = sizeof(timeout); +#else + struct timeval tv; + tv.tv_sec = timeout - timeout % 1000; + tv.tv_usec = (timeout % 1000) * 1000; + val = &tv; + size = sizeof(tv); +#endif + + m_impl->setSockOpt(SOL_SOCKET, SO_RCVTIMEO, val, size); + m_impl->setSockOpt(SOL_SOCKET, SO_SNDTIMEO, val, size); +} + +void Socket::setNonBlocking(bool nonBlocking) { + ReadLocker locker(m_mutex); + checkOpen("Socket::setNonBlocking"); +#ifdef WIN32 + unsigned long mode = nonBlocking ? 1 : 0; + if (ioctlsocket(m_impl->socketDesc, FIONBIO, &mode) != 0) + throw NetworkException::format("Cannot set socket non-blocking mode: %s", netErrorString()); +#else + int flags = fcntl(m_impl->socketDesc, F_GETFL, 0); + if (flags < 0) + throw NetworkException::format("fcntl failure getting socket flags: %s", netErrorString()); + flags = nonBlocking ? (flags | O_NONBLOCK) : (flags & ~O_NONBLOCK); + if (fcntl(m_impl->socketDesc, F_SETFL, flags) != 0) + throw NetworkException::format("fcntl failure setting non-blocking mode: %s", netErrorString()); +#endif +} + +NetworkMode Socket::networkMode() const { + ReadLocker locker(m_mutex); + return m_networkMode; +} + +SocketMode Socket::socketMode() const { + ReadLocker locker(m_mutex); + return m_socketMode; +} + +bool Socket::isActive() const { + return m_socketMode > SocketMode::Shutdown; +} + +bool Socket::isOpen() const { + return m_socketMode != SocketMode::Closed; +} + +void Socket::shutdown() { + ReadLocker locker(m_mutex); + doShutdown(); +} + +void Socket::close() { + WriteLocker locker(m_mutex); + doShutdown(); + doClose(); +} + +Socket::Socket(SocketType type, NetworkMode networkMode) + : m_networkMode(networkMode), m_impl(make_shared<SocketImpl>()), m_socketMode(SocketMode::Closed) { + if (m_networkMode == NetworkMode::IPv4) + m_impl->socketDesc = ::socket(AF_INET, type == SocketType::Tcp ? SOCK_STREAM : SOCK_DGRAM, 0); + else + m_impl->socketDesc = ::socket(AF_INET6, type == SocketType::Tcp ? SOCK_STREAM : SOCK_DGRAM, 0); + + if (invalidSocketDescriptor(m_impl->socketDesc)) + throw NetworkException(strf("cannot create socket: %s", netErrorString())); + + m_socketMode = SocketMode::Shutdown; + setTimeout(60000); + setNonBlocking(false); +} + +Socket::Socket(NetworkMode networkMode, SocketImplPtr impl, SocketMode socketMode) + : m_networkMode(networkMode), m_impl(impl), m_socketMode(socketMode) { + setTimeout(60000); + setNonBlocking(false); +} + +void Socket::checkOpen(char const* methodName) const { + if (m_socketMode == SocketMode::Closed) + throw SocketClosedException::format("Socket not open in %s", methodName); +} + +void Socket::doShutdown() { + if (m_socketMode <= SocketMode::Shutdown) + return; + + // Set socket mode first so that if this causes an exception the exception + // handlers know the socket is being shut down. + m_socketMode = SocketMode::Shutdown; + + if (m_impl->socketDesc > 0) { +#ifdef STAR_SYSTEM_FAMILY_WINDOWS + ::shutdown(m_impl->socketDesc, SD_BOTH); +#else + ::shutdown(m_impl->socketDesc, SHUT_RDWR); +#endif + } +} + +void Socket::doClose() { + if (m_socketMode == SocketMode::Closed) + return; + + // Set socket mode first so that if this causes an exception the exception + // handlers know the socket is being closed. + m_socketMode = SocketMode::Closed; + + if (m_impl->socketDesc > 0) { +#ifdef STAR_SYSTEM_FAMILY_WINDOWS + ::closesocket(m_impl->socketDesc); +#else + ::close(m_impl->socketDesc); +#endif + m_impl->socketDesc = 0; + } +} + +} diff --git a/source/core/StarSocket.hpp b/source/core/StarSocket.hpp new file mode 100644 index 0000000..1efe2c2 --- /dev/null +++ b/source/core/StarSocket.hpp @@ -0,0 +1,98 @@ +#ifndef STAR_SOCKET_HPP +#define STAR_SOCKET_HPP + +#include "StarHostAddress.hpp" +#include "StarThread.hpp" + +namespace Star { + +// Thrown when some call on a socket failed because the socket is *either* +// closed or shutdown, for other errors sockets will throw NetworkException +STAR_EXCEPTION(SocketClosedException, NetworkException); + +STAR_STRUCT(SocketImpl); +STAR_CLASS(Socket); + +enum class SocketMode { + Closed, + Shutdown, + Bound, + Connected +}; + +struct SocketPollQueryEntry { + // Query whether the tcp socket is readable + bool readable; + // Query whether the tcp socket is writable + bool writable; +}; + +struct SocketPollResultEntry { + // The tcp socket can be read without blocking + bool readable; + // The tcp socket can be written without blocking + bool writable; + // The tcp socket has had an error condition, or it has been closed. + bool exception; +}; + +typedef Map<SocketPtr, SocketPollQueryEntry> SocketPollQuery; +typedef Map<SocketPtr, SocketPollResultEntry> SocketPollResult; + +class Socket { +public: + // Waits for sockets that are readable, writiable, or have pending error + // conditions within the given timeout. Returns result if any sockets are + // ready for I/O or have had error events occur on them within the timeout, + // nothing otherwise. If socket hangup occurs during this call, this will + // automatically shut down the socket. + static Maybe<SocketPollResult> poll(SocketPollQuery const& query, unsigned timeout); + + ~Socket(); + + void bind(HostAddressWithPort const& address); + void listen(int backlog); + + // Sockets default to blocking mode + void setNonBlocking(bool nonBlocking); + // Sockets default to 60 second timeout + void setTimeout(unsigned millis); + + NetworkMode networkMode() const; + SocketMode socketMode() const; + + // Is the socketMode either Bound or Connected? + bool isActive() const; + + // Is the socketMode not closed? + bool isOpen() const; + + // Shuts down the underlying socket only. + void shutdown(); + + // Shuts down and closes the underlying socket. + void close(); + +protected: + enum class SocketType { + Tcp, + Udp + }; + + Socket(SocketType type, NetworkMode networkMode); + Socket(NetworkMode networkMode, SocketImplPtr impl, SocketMode socketMode); + + void checkOpen(char const* methodName) const; + void doShutdown(); + void doClose(); + + mutable ReadersWriterMutex m_mutex; + NetworkMode m_networkMode; + SocketImplPtr m_impl; + atomic<SocketMode> m_socketMode; + HostAddressWithPort m_localAddress; +}; + +} + +#endif diff --git a/source/core/StarSpatialHash2D.hpp b/source/core/StarSpatialHash2D.hpp new file mode 100644 index 0000000..2241dfd --- /dev/null +++ b/source/core/StarSpatialHash2D.hpp @@ -0,0 +1,337 @@ +#ifndef STAR_SPATIAL_HASH_2D_HPP +#define STAR_SPATIAL_HASH_2D_HPP + +#include "StarRect.hpp" +#include "StarMap.hpp" +#include "StarSet.hpp" +#include "StarBlockAllocator.hpp" + +namespace Star { + +// Dual-map based on key and 2 dimensional bounding rectangle. Implements a 2d +// spatial hash for fast bounding box queries. Each entry may have more than +// one bounding rectangle. +template <typename KeyT, typename ScalarT, typename ValueT, typename IntT = int, size_t AllocatorBlockSize = 4096> +class SpatialHash2D { +public: + typedef KeyT Key; + typedef ScalarT Scalar; + typedef Box<ScalarT, 2> Rect; + typedef typename Rect::Coord Coord; + typedef ValueT Value; + + struct Entry { + Entry(); + + SmallList<Rect, 2> rects; + Value value; + }; + + typedef StableHashMap<Key, Entry, hash<Key>, std::equal_to<Key>, BlockAllocator<pair<Key const, Entry>, AllocatorBlockSize>> EntryMap; + + SpatialHash2D(Scalar const& sectorSize); + + List<Key> keys() const; + List<Value> values() const; + EntryMap const& entries() const; + + size_t size() const; + + bool contains(Key const& key) const; + + Value const& get(Key const& key) const; + Value& get(Key const& key); + + // Returns default constructed value if key not found + Value value(Key const& key) const; + + // Query values from several bounding boxes at once with no duplicates. + List<Value> queryValues(Rect const& rect) const; + template <typename RectCollection> + List<Value> queryValues(RectCollection const& rects) const; + + // Iterate over entries in the given bounding boxes without duplication. It + // is safe to modify rects or add entries from the given callback, but it is + // not safe to remove entries from it. + template <typename Function> + void forEach(Rect const& rect, Function&& function) const; + template <typename RectCollection, typename Function> + void forEach(RectCollection const& rects, Function&& function) const; + + void set(Key const& key, Coord const& pos); + void set(Key const& key, Rect const& rect); + + template <typename RectCollection> + void set(Key const& key, RectCollection const& rects); + + void set(Key const& key, Coord const& pos, Value value); + void set(Key const& key, Rect const& rect, Value value); + + template <typename RectCollection> + void set(Key const& key, RectCollection const& rects, Value value); + + Maybe<Value> remove(Key const& key); + + // Recalculates every item in sector map + void setSectorSize(Scalar const& sectorSize); + +private: + typedef Vector<IntT, 2> Sector; + typedef Box<IntT, 2> SectorRange; + typedef HashSet<Entry const*, hash<Entry const*>, std::equal_to<Entry const*>> SectorEntrySet; + typedef HashMap<Sector, SectorEntrySet> SectorMap; + + SectorRange getSectors(Rect const& r) const; + + void addSpatial(Entry const* entry); + void removeSpatial(Entry const* entry); + + template <typename RectCollection> + void updateSpatial(Entry* entry, RectCollection const& rects); + + Scalar m_sectorSize; + EntryMap m_entryMap; + SectorMap m_sectorMap; +}; + +template <typename KeyT, typename ScalarT, typename ValueT, typename IntT, size_t AllocatorBlockSize> +SpatialHash2D<KeyT, ScalarT, ValueT, IntT, AllocatorBlockSize>::Entry::Entry() + : value() {} + +template <typename KeyT, typename ScalarT, typename ValueT, typename IntT, size_t AllocatorBlockSize> +SpatialHash2D<KeyT, ScalarT, ValueT, IntT, AllocatorBlockSize>::SpatialHash2D(Scalar const& sectorSize) + : m_sectorSize(sectorSize) {} + +template <typename KeyT, typename ScalarT, typename ValueT, typename IntT, size_t AllocatorBlockSize> +List<KeyT> SpatialHash2D<KeyT, ScalarT, ValueT, IntT, AllocatorBlockSize>::keys() const { + return m_entryMap.keys(); +} + +template <typename KeyT, typename ScalarT, typename ValueT, typename IntT, size_t AllocatorBlockSize> +List<typename SpatialHash2D<KeyT, ScalarT, ValueT, IntT, AllocatorBlockSize>::Value> SpatialHash2D<KeyT, ScalarT, ValueT, IntT, AllocatorBlockSize>::values() const { + List<Value> values; + for (auto const& pair : m_entryMap) + values.append(pair.second.value); + + return values; +} + +template <typename KeyT, typename ScalarT, typename ValueT, typename IntT, size_t AllocatorBlockSize> +typename SpatialHash2D<KeyT, ScalarT, ValueT, IntT, AllocatorBlockSize>::EntryMap const& +SpatialHash2D<KeyT, ScalarT, ValueT, IntT, AllocatorBlockSize>::entries() const { + return m_entryMap; +} + +template <typename KeyT, typename ScalarT, typename ValueT, typename IntT, size_t AllocatorBlockSize> +size_t SpatialHash2D<KeyT, ScalarT, ValueT, IntT, AllocatorBlockSize>::size() const { + return m_entryMap.size(); +} + +template <typename KeyT, typename ScalarT, typename ValueT, typename IntT, size_t AllocatorBlockSize> +bool SpatialHash2D<KeyT, ScalarT, ValueT, IntT, AllocatorBlockSize>::contains(Key const& key) const { + return m_entryMap.contains(key); +} + +template <typename KeyT, typename ScalarT, typename ValueT, typename IntT, size_t AllocatorBlockSize> +typename SpatialHash2D<KeyT, ScalarT, ValueT, IntT, AllocatorBlockSize>::Value const& SpatialHash2D<KeyT, ScalarT, ValueT, IntT, AllocatorBlockSize>::get( + Key const& key) const { + return m_entryMap.get(key).value; +} + +template <typename KeyT, typename ScalarT, typename ValueT, typename IntT, size_t AllocatorBlockSize> +typename SpatialHash2D<KeyT, ScalarT, ValueT, IntT, AllocatorBlockSize>::Value& SpatialHash2D<KeyT, ScalarT, ValueT, IntT, AllocatorBlockSize>::get( + Key const& key) { + return m_entryMap.get(key).value; +} + +template <typename KeyT, typename ScalarT, typename ValueT, typename IntT, size_t AllocatorBlockSize> +typename SpatialHash2D<KeyT, ScalarT, ValueT, IntT, AllocatorBlockSize>::Value SpatialHash2D<KeyT, ScalarT, ValueT, IntT, AllocatorBlockSize>::value( + Key const& key) const { + auto iter = m_entryMap.find(key); + if (iter == m_entryMap.end()) + return Value(); + else + return iter->second.value; +} + +template <typename KeyT, typename ScalarT, typename ValueT, typename IntT, size_t AllocatorBlockSize> +List<ValueT> SpatialHash2D<KeyT, ScalarT, ValueT, IntT, AllocatorBlockSize>::queryValues(Rect const& rect) const { + return queryValues(initializer_list<Rect>{rect}); +} + +template <typename KeyT, typename ScalarT, typename ValueT, typename IntT, size_t AllocatorBlockSize> +template <typename RectCollection> +List<ValueT> SpatialHash2D<KeyT, ScalarT, ValueT, IntT, AllocatorBlockSize>::queryValues(RectCollection const& rects) const { + List<Value> values; + forEach(rects, [&values](Value const& value) { + values.append(value); + }); + return values; +} + +template <typename KeyT, typename ScalarT, typename ValueT, typename IntT, size_t AllocatorBlockSize> +template <typename Function> +void SpatialHash2D<KeyT, ScalarT, ValueT, IntT, AllocatorBlockSize>::forEach(Rect const& rect, Function&& function) const { + return forEach(initializer_list<Rect>{rect}, forward<Function>(function)); +} + +template <typename KeyT, typename ScalarT, typename ValueT, typename IntT, size_t AllocatorBlockSize> +template <typename RectCollection, typename Function> +void SpatialHash2D<KeyT, ScalarT, ValueT, IntT, AllocatorBlockSize>::forEach(RectCollection const& rects, Function&& function) const { + SmallList<Entry const*, 32> foundEntries; + + for (Rect const& rect : rects) { + if (rect.isNull()) + continue; + + auto sectorResult = getSectors(rect); + + for (IntT x = sectorResult.xMin(); x < sectorResult.xMax(); ++x) { + for (IntT y = sectorResult.yMin(); y < sectorResult.yMax(); ++y) { + auto i = m_sectorMap.find(Sector{x, y}); + if (i != m_sectorMap.end()) { + for (auto e : i->second) { + for (Rect const& r : e->rects) { + if (r.intersects(rect)) { + foundEntries.append(e); + break; + } + } + } + } + } + } + } + + // Rather than keep a Set of keys to avoid duplication in found entries, it + // is much faster to simply keep all encountered intersected entries and then + // sort them later for all but the most massive and most populated searches, + // due to the allocation cost of Set and HashSet. + sort(foundEntries); + + // Looping over the found entries in sorted order with potential duplication, + // so need to skip over the entry if the previous entry is the same as the + // current entry + Entry const* prev = nullptr; + for (auto const& entry : foundEntries) { + if (entry == prev) + continue; + prev = entry; + function(entry->value); + } +} + +template <typename KeyT, typename ScalarT, typename ValueT, typename IntT, size_t AllocatorBlockSize> +void SpatialHash2D<KeyT, ScalarT, ValueT, IntT, AllocatorBlockSize>::set(Key const& key, Coord const& pos) { + set(key, {Rect(pos, pos)}); +} + +template <typename KeyT, typename ScalarT, typename ValueT, typename IntT, size_t AllocatorBlockSize> +void SpatialHash2D<KeyT, ScalarT, ValueT, IntT, AllocatorBlockSize>::set(Key const& key, Rect const& rect) { + set(key, {rect}); +} + +template <typename KeyT, typename ScalarT, typename ValueT, typename IntT, size_t AllocatorBlockSize> +template <typename RectCollection> +void SpatialHash2D<KeyT, ScalarT, ValueT, IntT, AllocatorBlockSize>::set(Key const& key, RectCollection const& rects) { + updateSpatial(&m_entryMap.get(key), rects); +} + +template <typename KeyT, typename ScalarT, typename ValueT, typename IntT, size_t AllocatorBlockSize> +void SpatialHash2D<KeyT, ScalarT, ValueT, IntT, AllocatorBlockSize>::set(Key const& key, Coord const& pos, Value value) { + set(key, {Rect(pos, pos)}, move(value)); +} + +template <typename KeyT, typename ScalarT, typename ValueT, typename IntT, size_t AllocatorBlockSize> +void SpatialHash2D<KeyT, ScalarT, ValueT, IntT, AllocatorBlockSize>::set(Key const& key, Rect const& rect, Value value) { + set(key, {rect}, move(value)); +} + +template <typename KeyT, typename ScalarT, typename ValueT, typename IntT, size_t AllocatorBlockSize> +template <typename RectCollection> +void SpatialHash2D<KeyT, ScalarT, ValueT, IntT, AllocatorBlockSize>::set(Key const& key, RectCollection const& rects, Value value) { + Entry& entry = m_entryMap[key]; + entry.value = move(value); + updateSpatial(&entry, rects); +} + +template <typename KeyT, typename ScalarT, typename ValueT, typename IntT, size_t AllocatorBlockSize> +auto SpatialHash2D<KeyT, ScalarT, ValueT, IntT, AllocatorBlockSize>::remove(Key const& key) -> Maybe<Value> { + auto iter = m_entryMap.find(key); + if (iter == m_entryMap.end()) + return {}; + + removeSpatial(&iter->second); + Maybe<Value> val = move(iter->second.value); + m_entryMap.erase(iter); + return val; +} + +template <typename KeyT, typename ScalarT, typename ValueT, typename IntT, size_t AllocatorBlockSize> +void SpatialHash2D<KeyT, ScalarT, ValueT, IntT, AllocatorBlockSize>::setSectorSize(Scalar const& sectorSize) { + m_sectorSize = sectorSize; + m_sectorMap.clear(); + for (auto const& pair : m_entryMap) + addSpatial(pair.first, pair.second); +} + +template <typename KeyT, typename ScalarT, typename ValueT, typename IntT, size_t AllocatorBlockSize> +typename SpatialHash2D<KeyT, ScalarT, ValueT, IntT, AllocatorBlockSize>::SectorRange SpatialHash2D<KeyT, ScalarT, ValueT, IntT, AllocatorBlockSize>::getSectors(Rect const& r) const { + return SectorRange( + floor(r.xMin() / m_sectorSize), + floor(r.yMin() / m_sectorSize), + ceil(r.xMax() / m_sectorSize), + ceil(r.yMax() / m_sectorSize)); +} + +template <typename KeyT, typename ScalarT, typename ValueT, typename IntT, size_t AllocatorBlockSize> +void SpatialHash2D<KeyT, ScalarT, ValueT, IntT, AllocatorBlockSize>::addSpatial(Entry const* entry) { + for (Rect const& rect : entry->rects) { + if (rect.isNull()) + continue; + + auto sectorResult = getSectors(rect); + for (IntT x = sectorResult.xMin(); x < sectorResult.xMax(); ++x) { + for (IntT y = sectorResult.yMin(); y < sectorResult.yMax(); ++y) { + Sector sector(x, y); + SectorEntrySet* p = m_sectorMap.ptr(sector); + if (!p) + p = &m_sectorMap.add(sector, SectorEntrySet()); + p->add(entry); + } + } + } +} + +template <typename KeyT, typename ScalarT, typename ValueT, typename IntT, size_t AllocatorBlockSize> +void SpatialHash2D<KeyT, ScalarT, ValueT, IntT, AllocatorBlockSize>::removeSpatial(Entry const* entry) { + for (Rect const& rect : entry->rects) { + if (rect.isNull()) + continue; + + auto sectorResult = getSectors(rect); + for (IntT x = sectorResult.xMin(); x < sectorResult.xMax(); ++x) { + for (IntT y = sectorResult.yMin(); y < sectorResult.yMax(); ++y) { + auto i = m_sectorMap.find(Sector{x, y}); + if (i != m_sectorMap.end()) { + i->second.remove(entry); + if (i->second.empty()) + m_sectorMap.erase(i); + } + } + } + } +} + +template <typename KeyT, typename ScalarT, typename ValueT, typename IntT, size_t AllocatorBlockSize> +template <typename RectCollection> +void SpatialHash2D<KeyT, ScalarT, ValueT, IntT, AllocatorBlockSize>::updateSpatial(Entry* entry, RectCollection const& rects) { + removeSpatial(entry); + entry->rects.clear(); + entry->rects.appendAll(rects); + addSpatial(entry); +} + +} + +#endif diff --git a/source/core/StarSpline.hpp b/source/core/StarSpline.hpp new file mode 100644 index 0000000..4656eaa --- /dev/null +++ b/source/core/StarSpline.hpp @@ -0,0 +1,158 @@ +#ifndef STAR_SPLINE_HPP +#define STAR_SPLINE_HPP + +#include "StarVector.hpp" +#include "StarInterpolation.hpp" +#include "StarLogging.hpp" +#include "StarLruCache.hpp" + +namespace Star { + +// Implementation of DeCasteljau Algorithm for Bezier Curves +template <typename DataT, size_t Dimension, size_t Order, class PointT = Vector<DataT, Dimension>> +class Spline : public Array<PointT, Order + 1> { +public: + typedef Array<PointT, Order + 1> PointData; + + template <typename... T> + Spline(PointT const& e1, T const&... rest) + : PointData(e1, rest...) { + m_pointCache.setMaxSize(1000); + m_lengthCache.setMaxSize(1000); + } + + Spline() : PointData(PointData::filled(PointT())) { + m_pointCache.setMaxSize(1000); + m_lengthCache.setMaxSize(1000); + } + + PointT pointAt(float t) const { + float u = clamp<float>(t, 0, 1); + if (u != t) { + t = u; + Logger::warn("Passed out of range time to Spline::pointAt"); + } + + if (auto p = m_pointCache.ptr(t)) + return *p; + + PointData intermediates(*this); + PointData temp; + for (size_t order = Order + 1; order > 1; order--) { + for (size_t i = 1; i < order; i++) { + temp[i - 1] = lerp(t, intermediates[i - 1], intermediates[i]); + } + intermediates = std::move(temp); + } + + m_pointCache.set(t, intermediates[0]); + return intermediates[0]; + } + + PointT tangentAt(float t) const { + float u = clamp<float>(t, 0, 1); + if (u != t) { + t = u; + Logger::warn("Passed out of range time to Spline::tangentAt"); + } + + // constructs a hodograph and returns pointAt + Spline<DataT, Dimension, Order - 1> hodograph; + for (size_t i = 0; i < Order; i++) { + hodograph[i] = ((*this)[i + 1] - (*this)[i]) * Order; + } + return hodograph.pointAt(t); + } + + DataT length(float begin = 0, float end = 1, size_t subdivisions = 100) const { + if (!(begin <= 1 && begin >= 0 && end <= 1 && end >= 0 && begin <= end)) { + Logger::warn("Passed invalid range to Spline::length"); + return 0; + } + + if (!begin) { + if (auto p = m_lengthCache.ptr(end)) + return *p; + } + + DataT res = 0; + PointT previousPoint = pointAt(begin); + for (size_t i = 1; i <= subdivisions; i++) { + PointT currentPoint = pointAt(i / subdivisions * (end - begin)); + res += (currentPoint - previousPoint).magnitude(); + previousPoint = currentPoint; + } + + if (!begin) + m_lengthCache.set(end, res); + + return res; + } + + float arcLenPara(float u, DataT epsilon = .01) const { + if (u == 0) + return 0; + if (u == 1) + return 1; + u = clamp<float>(u, 0, 1); + if (u == 0 || u == 1) { + Logger::warn("Passed out of range time to Spline::arcLenPara"); + return u; + } + DataT targetLength = length() * u; + float t = .5; + float lower = 0; + float upper = 1; + DataT approxLen = length(0, t); + while (targetLength - approxLen > epsilon || targetLength - approxLen < -epsilon) { + if (targetLength > approxLen) { + lower = t; + } else { + upper = t; + } + t = (upper - lower) * .5 + lower; + approxLen = length(0, t); + } + return t; + } + + PointT& origin() { + m_pointCache.clear(); + m_lengthCache.clear(); + return (*this)[0]; + } + + PointT const& origin() const { + return (*this)[0]; + } + + PointT& dest() { + m_pointCache.clear(); + m_lengthCache.clear(); + return (*this)[Order]; + } + + PointT const& dest() const { + return (*this)[Order]; + } + + PointT& operator[](size_t index) { + m_pointCache.clear(); + m_lengthCache.clear(); + return PointData::operator[](index); + } + + PointT const& operator[](size_t index) const { + return PointData::operator[](index); + } + +protected: + mutable LruCache<float, PointT> m_pointCache; + mutable LruCache<float, DataT> m_lengthCache; +}; + +typedef Spline<float, 2, 3, Vec2F> CSplineF; + +} + +#endif diff --git a/source/core/StarStaticRandom.hpp b/source/core/StarStaticRandom.hpp new file mode 100644 index 0000000..97cfdef --- /dev/null +++ b/source/core/StarStaticRandom.hpp @@ -0,0 +1,131 @@ +#ifndef STAR_STATIC_RANDOM_HPP +#define STAR_STATIC_RANDOM_HPP + +#include "StarString.hpp" +#include "StarXXHash.hpp" + +namespace Star { + +// Cross-platform, predictable random number generators based on XXHash. +// Supports primitive types as well as strings for input data. + +inline void staticRandomHash32Iter(XXHash32&) {} + +template <typename T, typename... TL> +void staticRandomHash32Iter(XXHash32& hash, T const& v, TL const&... rest) { + xxHash32Push(hash, v); + staticRandomHash32Iter(hash, rest...); +} + +template <typename T, typename... TL> +uint32_t staticRandomHash32(T const& v, TL const&... rest) { + XXHash32 hash(2938728349u); + staticRandomHash32Iter(hash, v, rest...); + return hash.digest(); +} + +inline void staticRandomHash64Iter(XXHash64&) {} + +template <typename T, typename... TL> +void staticRandomHash64Iter(XXHash64& hash, T const& v, TL const&... rest) { + xxHash64Push(hash, v); + staticRandomHash64Iter(hash, rest...); +} + +template <typename T, typename... TL> +uint64_t staticRandomHash64(T const& v, TL const&... rest) { + XXHash64 hash(1997293021376312589); + staticRandomHash64Iter(hash, v, rest...); + return hash.digest(); +} + +template <typename T, typename... TL> +uint32_t staticRandomU32(T const& d, TL const&... rest) { + return staticRandomHash32(d, rest...); +} + +template <typename T, typename... TL> +uint64_t staticRandomU64(T const& d, TL const&... rest) { + return staticRandomHash64(d, rest...); +} + +template <typename T, typename... TL> +int32_t staticRandomI32(T const& d, TL const&... rest) { + return (int32_t)staticRandomU32(d, rest...); +} + +template <typename T, typename... TL> +int32_t staticRandomI32Range(int32_t min, int32_t max, T const& d, TL const&... rest) { + uint64_t denom = (uint64_t)(-1) / ((uint64_t)(max - min) + 1); + return (int32_t)(staticRandomU64(d, rest...) / denom + min); +} + +template <typename T, typename... TL> +uint32_t staticRandomU32Range(uint32_t min, uint32_t max, T const& d, TL const&... rest) { + uint64_t denom = (uint64_t)(-1) / ((uint64_t)(max - min) + 1); + return staticRandomU64(d, rest...) / denom + min; +} + +template <typename T, typename... TL> +int64_t staticRandomI64(T const& d, TL const&... rest) { + return (int64_t)staticRandomU64(d, rest...); +} + +// Generates values in the range [0.0, 1.0] +template <typename T, typename... TL> +float staticRandomFloat(T const& d, TL const&... rest) { + return (staticRandomU32(d, rest...) & 0x7fffffff) / 2147483648.0; +} + +template <typename T, typename... TL> +float staticRandomFloatRange(float min, float max, T const& d, TL const&... rest) { + return staticRandomFloat(d, rest...) * (max - min) + min; +} + +// Generates values in the range [0.0, 1.0] +template <typename T, typename... TL> +double staticRandomDouble(T const& d, TL const&... rest) { + return (staticRandomU64(d, rest...) & 0x7fffffffffffffff) / 9223372036854775808.0; +} + +template <typename T, typename... TL> +double staticRandomDoubleRange(double min, double max, T const& d, TL const&... rest) { + return staticRandomDouble(d, rest...) * (max - min) + min; +} + +template <typename Container, typename T, typename... TL> +typename Container::value_type& staticRandomFrom(Container& container, T const& d, TL const&... rest) { + auto i = container.begin(); + std::advance(i, staticRandomI32Range(0, container.size() - 1, d, rest...)); + return *i; +} + +template <typename Container, typename T, typename... TL> +typename Container::value_type const& staticRandomFrom(Container const& container, T const& d, TL const&... rest) { + auto i = container.begin(); + std::advance(i, staticRandomI32Range(0, container.size() - 1, d, rest...)); + return *i; +} + +template <typename Container, typename T, typename... TL> +typename Container::value_type staticRandomValueFrom(Container const& container, T const& d, TL const&... rest) { + if (container.empty()) { + return {}; + } else { + auto i = container.begin(); + std::advance(i, staticRandomI32Range(0, container.size() - 1, d, rest...)); + return *i; + } +} + +template <typename Container, typename T, typename... TL> +void staticRandomShuffle(Container& container, T const& d, TL const&... rest) { + int mix = 0; + std::random_shuffle(container.begin(), + container.end(), + [&](size_t max) { return staticRandomU32Range(0, max - 1, ++mix, d, rest...); }); +} + +} + +#endif diff --git a/source/core/StarStaticVector.hpp b/source/core/StarStaticVector.hpp new file mode 100644 index 0000000..693dc10 --- /dev/null +++ b/source/core/StarStaticVector.hpp @@ -0,0 +1,403 @@ +#ifndef STAR_STATIC_VECTOR_HPP +#define STAR_STATIC_VECTOR_HPP + +#include "StarException.hpp" + +namespace Star { + +STAR_EXCEPTION(StaticVectorSizeException, StarException); + +// Stack allocated vector of elements with a dynamic size which must be less +// than a given maximum. Acts like a vector with a built-in allocator of a +// maximum size, throws bad_alloc on attempting to resize beyond the maximum +// size. +template <typename Element, size_t MaxSize> +class StaticVector { +public: + typedef Element* iterator; + typedef Element const* const_iterator; + + typedef std::reverse_iterator<iterator> reverse_iterator; + typedef std::reverse_iterator<const_iterator> const_reverse_iterator; + + typedef Element value_type; + + typedef Element& reference; + typedef Element const& const_reference; + + static constexpr size_t MaximumSize = MaxSize; + + StaticVector(); + StaticVector(StaticVector const& other); + StaticVector(StaticVector&& other); + template <typename OtherElement, size_t OtherMaxSize> + StaticVector(StaticVector<OtherElement, OtherMaxSize> const& other); + template <class Iterator> + StaticVector(Iterator first, Iterator last); + StaticVector(size_t size, Element const& value = Element()); + StaticVector(initializer_list<Element> list); + ~StaticVector(); + + StaticVector& operator=(StaticVector const& other); + StaticVector& operator=(StaticVector&& other); + StaticVector& operator=(std::initializer_list<Element> list); + + size_t size() const; + bool empty() const; + void resize(size_t size, Element const& e = Element()); + + reference at(size_t i); + const_reference at(size_t i) const; + + reference operator[](size_t i); + const_reference operator[](size_t i) const; + + const_iterator begin() const; + const_iterator end() const; + + iterator begin(); + iterator end(); + + const_reverse_iterator rbegin() const; + const_reverse_iterator rend() const; + + reverse_iterator rbegin(); + reverse_iterator rend(); + + // Pointer to internal data, always valid even if empty. + Element const* ptr() const; + Element* ptr(); + + void push_back(Element e); + void pop_back(); + + iterator insert(iterator pos, Element e); + template <typename Iterator> + iterator insert(iterator pos, Iterator begin, Iterator end); + iterator insert(iterator pos, initializer_list<Element> list); + + template <typename... Args> + void emplace(iterator pos, Args&&... args); + + template <typename... Args> + void emplace_back(Args&&... args); + + void clear(); + + iterator erase(iterator pos); + iterator erase(iterator begin, iterator end); + + bool operator==(StaticVector const& other) const; + bool operator!=(StaticVector const& other) const; + bool operator<(StaticVector const& other) const; + +private: + size_t m_size; + typename std::aligned_storage<MaxSize * sizeof(Element), alignof(Element)>::type m_elements; +}; + +template <typename Element, size_t MaxSize> +StaticVector<Element, MaxSize>::StaticVector() + : m_size(0) {} + +template <typename Element, size_t MaxSize> +StaticVector<Element, MaxSize>::~StaticVector() { + clear(); +} + +template <typename Element, size_t MaxSize> +StaticVector<Element, MaxSize>::StaticVector(StaticVector const& other) + : StaticVector() { + insert(begin(), other.begin(), other.end()); +} + +template <typename Element, size_t MaxSize> +StaticVector<Element, MaxSize>::StaticVector(StaticVector&& other) + : StaticVector() { + for (auto& e : other) + emplace_back(move(e)); +} + +template <typename Element, size_t MaxSize> +template <typename OtherElement, size_t OtherMaxSize> +StaticVector<Element, MaxSize>::StaticVector(StaticVector<OtherElement, OtherMaxSize> const& other) + : StaticVector() { + for (auto const& e : other) + emplace_back(e); +} + +template <typename Element, size_t MaxSize> +template <class Iterator> +StaticVector<Element, MaxSize>::StaticVector(Iterator first, Iterator last) + : StaticVector() { + insert(begin(), first, last); +} + +template <typename Element, size_t MaxSize> +StaticVector<Element, MaxSize>::StaticVector(size_t size, Element const& value) + : StaticVector() { + resize(size, value); +} + +template <typename Element, size_t MaxSize> +StaticVector<Element, MaxSize>::StaticVector(initializer_list<Element> list) + : StaticVector() { + for (auto const& e : list) + emplace_back(e); +} + +template <typename Element, size_t MaxSize> +auto StaticVector<Element, MaxSize>::operator=(StaticVector const& other) -> StaticVector& { + if (this == &other) + return *this; + + resize(other.size()); + for (size_t i = 0; i < m_size; ++i) + operator[](i) = other[i]; + + return *this; +} + +template <typename Element, size_t MaxSize> +auto StaticVector<Element, MaxSize>::operator=(StaticVector&& other) -> StaticVector& { + resize(other.size()); + for (size_t i = 0; i < m_size; ++i) + operator[](i) = move(other[i]); + + return *this; +} + +template <typename Element, size_t MaxSize> +auto StaticVector<Element, MaxSize>::operator=(std::initializer_list<Element> list) -> StaticVector& { + resize(list.size()); + for (size_t i = 0; i < m_size; ++i) + operator[](i) = move(list[i]); + return *this; +} + +template <typename Element, size_t MaxSize> +size_t StaticVector<Element, MaxSize>::size() const { + return m_size; +} + +template <typename Element, size_t MaxSize> +bool StaticVector<Element, MaxSize>::empty() const { + return m_size == 0; +} + +template <typename Element, size_t MaxSize> +void StaticVector<Element, MaxSize>::resize(size_t size, Element const& e) { + if (size > MaxSize) + throw StaticVectorSizeException::format("StaticVector::resize(%s) out of range %s", m_size + size, MaxSize); + + for (size_t i = m_size; i > size; --i) + pop_back(); + for (size_t i = m_size; i < size; ++i) + emplace_back(e); +} + +template <typename Element, size_t MaxSize> +auto StaticVector<Element, MaxSize>::at(size_t i) -> reference { + if (i >= m_size) + throw OutOfRangeException::format("out of range in StaticVector::at(%s)", i); + return ptr()[i]; +} + +template <typename Element, size_t MaxSize> +auto StaticVector<Element, MaxSize>::at(size_t i) const -> const_reference { + if (i >= m_size) + throw OutOfRangeException::format("out of range in StaticVector::at(%s)", i); + return ptr()[i]; +} + +template <typename Element, size_t MaxSize> +auto StaticVector<Element, MaxSize>::operator[](size_t i) -> reference { + starAssert(i < m_size); + return ptr()[i]; +} + +template <typename Element, size_t MaxSize> +auto StaticVector<Element, MaxSize>::operator[](size_t i) const -> const_reference { + starAssert(i < m_size); + return ptr()[i]; +} + +template <typename Element, size_t MaxSize> +auto StaticVector<Element, MaxSize>::begin() const -> const_iterator { + return ptr(); +} + +template <typename Element, size_t MaxSize> +auto StaticVector<Element, MaxSize>::end() const -> const_iterator { + return ptr() + m_size; +} + +template <typename Element, size_t MaxSize> +auto StaticVector<Element, MaxSize>::begin() -> iterator { + return ptr(); +} + +template <typename Element, size_t MaxSize> +auto StaticVector<Element, MaxSize>::end() -> iterator { + return ptr() + m_size; +} + +template <typename Element, size_t MaxSize> +auto StaticVector<Element, MaxSize>::rbegin() const -> const_reverse_iterator { + return const_reverse_iterator(end()); +} + +template <typename Element, size_t MaxSize> +auto StaticVector<Element, MaxSize>::rend() const -> const_reverse_iterator { + return const_reverse_iterator(begin()); +} + +template <typename Element, size_t MaxSize> +auto StaticVector<Element, MaxSize>::rbegin() -> reverse_iterator { + return reverse_iterator(end()); +} + +template <typename Element, size_t MaxSize> +auto StaticVector<Element, MaxSize>::rend() -> reverse_iterator { + return reverse_iterator(begin()); +} + +template <typename Element, size_t MaxSize> +Element const* StaticVector<Element, MaxSize>::ptr() const { + return (Element const*)&m_elements; +} + +template <typename Element, size_t MaxSize> +Element* StaticVector<Element, MaxSize>::ptr() { + return (Element*)&m_elements; +} + +template <typename Element, size_t MaxSize> +void StaticVector<Element, MaxSize>::push_back(Element e) { + emplace_back(move(e)); +} + +template <typename Element, size_t MaxSize> +void StaticVector<Element, MaxSize>::pop_back() { + if (m_size == 0) + throw OutOfRangeException("StaticVector::pop_back called on empty StaticVector"); + --m_size; + (ptr() + m_size)->~Element(); +} + +template <typename Element, size_t MaxSize> +auto StaticVector<Element, MaxSize>::insert(iterator pos, Element e) -> iterator { + emplace(pos, move(e)); + return pos; +} + +template <typename Element, size_t MaxSize> +template <typename Iterator> +auto StaticVector<Element, MaxSize>::insert(iterator pos, Iterator begin, Iterator end) -> iterator { + size_t toAdd = std::distance(begin, end); + size_t startIndex = pos - ptr(); + size_t endIndex = startIndex + toAdd; + size_t toShift = m_size - startIndex; + + resize(m_size + toAdd); + + for (size_t i = toShift; i != 0; --i) + operator[](endIndex + i - 1) = move(operator[](startIndex + i - 1)); + + for (size_t i = 0; i != toAdd; ++i) + operator[](startIndex + i) = *begin++; + + return pos; +} + +template <typename Element, size_t MaxSize> +auto StaticVector<Element, MaxSize>::insert(iterator pos, initializer_list<Element> list) -> iterator { + return insert(pos, list.begin(), list.end()); +} + +template <typename Element, size_t MaxSize> +template <typename... Args> +void StaticVector<Element, MaxSize>::emplace(iterator pos, Args&&... args) { + size_t index = pos - ptr(); + resize(m_size + 1); + for (size_t i = m_size - 1; i != index; --i) + operator[](i) = move(operator[](i - 1)); + operator[](index) = Element(forward<Args>(args)...); +} + +template <typename Element, size_t MaxSize> +template <typename... Args> +void StaticVector<Element, MaxSize>::emplace_back(Args&&... args) { + if (m_size + 1 > MaxSize) + throw StaticVectorSizeException::format("StaticVector::emplace_back would extend StaticVector beyond size %s", MaxSize); + + m_size += 1; + new (ptr() + m_size - 1) Element(forward<Args>(args)...); +} + +template <typename Element, size_t MaxSize> +void StaticVector<Element, MaxSize>::clear() { + while (m_size != 0) + pop_back(); +} + +template <typename Element, size_t MaxSize> +auto StaticVector<Element, MaxSize>::erase(iterator pos) -> iterator { + size_t index = pos - ptr(); + for (size_t i = index; i < m_size - 1; ++i) + operator[](i) = move(operator[](i + 1)); + resize(m_size - 1); + return pos; +} + +template <typename Element, size_t MaxSize> +auto StaticVector<Element, MaxSize>::erase(iterator begin, iterator end) -> iterator { + size_t startIndex = begin - ptr(); + size_t endIndex = end - ptr(); + size_t toRemove = endIndex - startIndex; + for (size_t i = endIndex; i < m_size; ++i) + operator[](startIndex + (i - endIndex)) = move(operator[](i)); + resize(m_size - toRemove); + return begin; +} + +template <typename Element, size_t MaxSize> +bool StaticVector<Element, MaxSize>::operator==(StaticVector const& other) const { + if (this == &other) + return true; + + if (m_size != other.m_size) + return false; + for (size_t i = 0; i < m_size; ++i) { + if (operator[](i) != other[i]) + return false; + } + return true; +} + +template <typename Element, size_t MaxSize> +bool StaticVector<Element, MaxSize>::operator!=(StaticVector const& other) const { + return !operator==(other); +} + +template <typename Element, size_t MaxSize> +bool StaticVector<Element, MaxSize>::operator<(StaticVector const& other) const { + for (size_t i = 0; i < m_size; ++i) { + if (i >= other.size()) + return false; + + Element const& a = operator[](i); + Element const& b = other[i]; + + if (a < b) + return true; + else if (b < a) + return false; + } + + return m_size < other.size(); +} + +} + +#endif diff --git a/source/core/StarString.cpp b/source/core/StarString.cpp new file mode 100644 index 0000000..81c6b2e --- /dev/null +++ b/source/core/StarString.cpp @@ -0,0 +1,1137 @@ +#include "StarString.hpp" +#include "StarBytes.hpp" +#include "StarFormat.hpp" + +#include <cctype> +#include <regex> + +namespace Star { + +bool String::isSpace(Char c) { + return + c == 0x20 || // space + c == 0x09 || // horizontal tab + c == 0x0a || // newline + c == 0x0d || // carriage return + c == 0xfeff; // BOM or ZWNBSP +} + +bool String::isAsciiNumber(Char c) { + return c >= '0' && c <= '9'; +} + +bool String::isAsciiLetter(Char c) { + return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z'); +} + +String::Char String::toLower(Char c) { + if (c >= 'A' && c <= 'Z') + return c + 32; + else + return c; +} + +String::Char String::toUpper(Char c) { + if (c >= 'a' && c <= 'z') + return c - 32; + else + return c; +} + +bool String::charEqual(Char c1, Char c2, CaseSensitivity cs) { + if (cs == CaseInsensitive) + return toLower(c1) == toLower(c2); + else + return c1 == c2; +} + +String String::joinWith(String const& join, String const& left, String const& right) { + if (left.empty()) + return right; + if (right.empty()) + return left; + + if (left.endsWith(join)) { + if (right.beginsWith(join)) { + return left + right.substr(join.size()); + } + return left + right; + } else { + if (right.beginsWith(join)) { + return left + right; + } + return left + join + right; + } +} + +String::String() {} +String::String(String const& s) : m_string(s.m_string) {} +String::String(String&& s) : m_string(std::move(s.m_string)) {} +String::String(char const* s) : m_string(s) {} +String::String(char const* s, size_t n) : m_string(s, n) {} +String::String(std::string const& s) : m_string(s) {} +String::String(std::string&& s) : m_string(std::move(s)) {} + +String::String(std::wstring const& s) { + reserve(s.length()); + for (Char c : s) + append(c); +} + +String::String(Char const* s) { + while (*s) { + append(*s); + ++s; + } +} + +String::String(Char const* s, size_t n) { + reserve(n); + for (size_t idx = 0; idx < n; ++idx) { + append(*s); + ++s; + } +} + +String::String(Char c, size_t n) { + reserve(n); + for (size_t i = 0; i < n; ++i) + append(c); +} + +String::String(Char c) { + append(c); +} + +std::string const& String::utf8() const { + return m_string; +} + +std::string String::takeUtf8() { + return take(m_string); +} + +ByteArray String::utf8Bytes() const { + return ByteArray(m_string.c_str(), m_string.size()); +} + +char const* String::utf8Ptr() const { + return m_string.c_str(); +} + +size_t String::utf8Size() const { + return m_string.size(); +} + +std::wstring String::wstring() const { + std::wstring string; + for (Char c : *this) + string.push_back(c); + return string; +} + +String::WideString String::wideString() const { + WideString string; + string.reserve(m_string.size()); + for (Char c : *this) + string.push_back(c); + return string; +} + +String::const_iterator String::begin() const { + return const_iterator(m_string.begin()); +} + +String::const_iterator String::end() const { + return const_iterator(m_string.end()); +} + +size_t String::size() const { + return utf8Length(m_string.c_str(), m_string.size()); +} + +size_t String::length() const { + return size(); +} + +void String::clear() { + m_string.clear(); +} + +void String::reserve(size_t n) { + m_string.reserve(n); +} + +bool String::empty() const { + return m_string.empty(); +} + +String::Char String::operator[](size_t index) const { + auto it = begin(); + for (size_t i = 0; i < index; ++i) + ++it; + return *it; +} + +size_t CaseInsensitiveStringHash::operator()(String const& s) const { + PLHasher hash; + for (auto c : s) + hash.put(String::toLower(c)); + return hash.hash(); +} + +bool CaseInsensitiveStringCompare::operator()(String const& lhs, String const& rhs) const { + return lhs.equalsIgnoreCase(rhs); +} + +String::Char String::at(size_t i) const { + if (i > size()) + throw OutOfRangeException(strf("Out of range in String::at(%s)", i)); + return operator[](i); +} + +String String::toUpper() const { + String s; + s.reserve(m_string.length()); + for (Char c : *this) + s.append(toUpper(c)); + return s; +} + +String String::toLower() const { + String s; + s.reserve(m_string.length()); + for (Char c : *this) + s.append(toLower(c)); + return s; +} + +String String::titleCase() const { + String s; + s.reserve(m_string.length()); + bool capNext = true; + for (Char c : *this) { + if (capNext) + s.append(toUpper(c)); + else + s.append(toLower(c)); + capNext = !std::isalpha(c); + } + return s; +} + +bool String::endsWith(String const& end, CaseSensitivity cs) const { + auto endsize = end.size(); + if (endsize == 0) + return true; + + auto mysize = size(); + if (endsize > mysize) + return false; + + return compare(mysize - endsize, NPos, end, 0, NPos, cs) == 0; +} + +bool String::endsWith(Char end, CaseSensitivity cs) const { + if (size() == 0) + return false; + + return charEqual(end, operator[](size() - 1), cs); +} + +bool String::beginsWith(String const& beg, CaseSensitivity cs) const { + auto begSize = beg.size(); + if (begSize == 0) + return true; + + auto mysize = size(); + if (begSize > mysize) + return false; + + return compare(0, begSize, beg, 0, NPos, cs) == 0; +} + +bool String::beginsWith(Char beg, CaseSensitivity cs) const { + if (size() == 0) + return false; + + return charEqual(beg, operator[](0), cs); +} + +String String::reverse() const { + String ret; + ret.reserve(m_string.length()); + auto i = end(); + while (i != begin()) { + --i; + ret.append(*i); + } + return ret; +} + +String String::rot13() const { + String ret; + ret.reserve(m_string.length()); + for (auto c : *this) { + if ((c >= 'a' && c <= 'm') || (c >= 'A' && c <= 'M')) + c += 13; + else if ((c >= 'n' && c <= 'z') || (c >= 'N' && c <= 'Z')) + c -= 13; + ret.append(c); + } + return ret; +} + +StringList String::split(Char c, size_t maxSplit) const { + return split(String(c), maxSplit); +} + +StringList String::split(String const& pattern, size_t maxSplit) const { + StringList ret; + if (pattern.empty()) + return StringList(1, *this); + + size_t beg = 0; + while (true) { + if (ret.size() == maxSplit) { + ret.append(m_string.substr(beg)); + break; + } + + size_t end = m_string.find(pattern.m_string, beg); + if (end == NPos) { + ret.append(m_string.substr(beg)); + break; + } + ret.append(m_string.substr(beg, end - beg)); + beg = end + pattern.m_string.size(); + } + + starAssert(maxSplit == NPos || ret.size() <= maxSplit + 1); + return ret; +} + +StringList String::rsplit(Char c, size_t maxSplit) const { + return rsplitAny(String(c), maxSplit); +} + +StringList String::rsplit(String const& pattern, size_t maxSplit) const { + // This is really inefficient! + String v = reverse(); + String p = pattern.reverse(); + StringList l = v.split(p, maxSplit); + for (auto& s : l) + s = s.reverse(); + + Star::reverse(l); + return l; +} + +StringList String::splitAny(String const& chars, size_t maxSplit) const { + StringList ret; + String next; + bool doneSplitting = false; + for (auto c : *this) { + if (!doneSplitting && chars.hasCharOrWhitespace(c)) { + if (!next.empty()) + ret.append(take(next)); + } else { + if (ret.size() == maxSplit) + doneSplitting = true; + next.append(c); + } + } + if (!next.empty()) + ret.append(move(next)); + return ret; +} + +StringList String::rsplitAny(String const& chars, size_t maxSplit) const { + // This is really inefficient! + String v = reverse(); + StringList l = v.splitAny(chars, maxSplit); + for (auto& s : l) + s = s.reverse(); + + Star::reverse(l); + return l; +} + +StringList String::splitLines(size_t maxSplit) const { + return splitAny("\r\n", maxSplit); +} + +StringList String::splitWhitespace(size_t maxSplit) const { + return splitAny("", maxSplit); +} + +String String::extract(String const& chars) { + StringList l = splitAny(chars, 1); + if (l.size() == 0) { + return String(); + } else if (l.size() == 1) { + clear(); + return l.at(0); + } else { + *this = l.at(1); + return l.at(0); + } +} + +String String::rextract(String const& chars) { + if (empty()) + return String(); + + StringList l = rsplitAny(chars, 1); + if (l.size() == 1) { + clear(); + return l.at(0); + } else { + *this = l.at(0); + return l.at(1); + } +} + +bool String::hasChar(Char c) const { + for (Char ch : *this) + if (ch == c) + return true; + return false; +} + +bool String::hasCharOrWhitespace(Char c) const { + if (empty()) + return isSpace(c); + else + return hasChar(c); +} + +String String::replace(String const& rplc, String const& val) const { + size_t index; + size_t sz = size(); + size_t rsz = rplc.size(); + String ret; + ret.reserve(m_string.length()); + + if (rplc.empty()) + return *this; + + index = find(rplc); + if (index == NPos) + return *this; + + auto it = begin(); + for (size_t i = 0; i < index; ++i) + ret.append(*it++); + + while (index < sz) { + ret.append(val); + index += rsz; + for (size_t i = 0; i < rsz; ++i) + ++it; + + size_t nindex = find(rplc, index); + for (size_t i = index; i < nindex && i < sz; ++i) + ret.append(*it++); + + index = nindex; + } + return ret; +} + +String String::trimEnd(String const& pattern) const { + size_t end; + for (end = size(); end > 0; --end) { + Char ec = (*this)[end - 1]; + if (!pattern.hasCharOrWhitespace(ec)) + break; + } + return substr(0, end); +} + +String String::trimBeg(String const& pattern) const { + size_t beg; + for (beg = 0; beg < size(); ++beg) { + Char bc = (*this)[beg]; + if (!pattern.hasCharOrWhitespace(bc)) + break; + } + return substr(beg); +} + +String String::trim(String const& pattern) const { + return trimEnd(pattern).trimBeg(pattern); +} + +size_t String::find(Char c, size_t pos, CaseSensitivity cs) const { + auto it = begin(); + for (size_t i = 0; i < pos; ++i) { + if (it == end()) + break; + ++it; + } + + while (it != end()) { + if (charEqual(c, *it, cs)) + return pos; + ++pos; + ++it; + } + + return NPos; +} + +size_t String::find(String const& str, size_t pos, CaseSensitivity cs) const { + if (str.empty()) + return 0; + + auto it = begin(); + for (size_t i = 0; i < pos; ++i) { + if (it == end()) + break; + ++it; + } + + const_iterator sit = str.begin(); + const_iterator mit = it; + while (it != end()) { + if (charEqual(*sit, *mit, cs)) { + do { + ++mit; + ++sit; + if (sit == str.end()) + return pos; + else if (mit == end()) + break; + } while (charEqual(*sit, *mit, cs)); + sit = str.begin(); + } + ++pos; + mit = ++it; + } + + return NPos; +} + +size_t String::findLast(Char c, CaseSensitivity cs) const { + auto it = begin(); + + size_t found = NPos; + size_t pos = 0; + while (it != end()) { + if (charEqual(c, *it, cs)) + found = pos; + ++pos; + ++it; + } + + return found; +} + +size_t String::findLast(String const& str, CaseSensitivity cs) const { + if (str.empty()) + return 0; + + size_t pos = 0; + auto it = begin(); + size_t result = NPos; + const_iterator sit = str.begin(); + const_iterator mit = it; + while (it != end()) { + if (charEqual(*sit, *mit, cs)) { + do { + ++mit; + ++sit; + if (sit == str.end()) { + result = pos; + break; + } + if (mit == end()) + break; + } while (charEqual(*sit, *mit, cs)); + sit = str.begin(); + } + ++pos; + mit = ++it; + } + + return result; +} + +size_t String::findFirstOf(String const& pattern, size_t beg) const { + auto it = begin(); + size_t i; + for (i = 0; i < beg; ++i) + ++it; + + while (it != end()) { + if (pattern.hasCharOrWhitespace(*it)) + return i; + ++it; + ++i; + } + return NPos; +} + +size_t String::findFirstNotOf(String const& pattern, size_t beg) const { + auto it = begin(); + size_t i; + for (i = 0; i < beg; ++i) + ++it; + + while (it != end()) { + if (!pattern.hasCharOrWhitespace(*it)) + return i; + ++it; + ++i; + } + return NPos; +} + +size_t String::findNextBoundary(size_t index, bool backwards) const { + starAssert(index <= size()); + if (!backwards && (index == size())) + return index; + if (backwards) { + if (index == 0) + return 0; + index--; + } + Char c = this->at(index); + while (!isSpace(c)) { + if (backwards && (index == 0)) + return 0; + index += backwards ? -1 : 1; + if (index == size()) + return size(); + c = this->at(index); + } + while (isSpace(c)) { + if (backwards && (index == 0)) + return 0; + index += backwards ? -1 : 1; + if (index == size()) + return size(); + c = this->at(index); + } + if (backwards && !(index == size())) + return index + 1; + return index; +} + +String String::slice(SliceIndex a, SliceIndex b, int i) const { + auto wide = wideString(); + wide = Star::slice(wide, a, b, i); + return String(wide.c_str()); +} + +void String::append(String const& string) { + m_string.append(string.m_string); +} + +void String::append(std::string const& s) { + m_string.append(s); +} + +void String::append(Char const* s) { + while (s) + append(*s++); +} + +void String::append(Char const* s, size_t n) { + for (size_t i = 0; i < n; ++i) + append(s[i]); +} + +void String::append(char const* s) { + m_string.append(s); +} + +void String::append(char const* s, size_t n) { + m_string.append(s, n); +} + +void String::append(Char c) { + char conv[6]; + size_t size = utf8EncodeChar(conv, c, 6); + append(conv, size); +} + +void String::prepend(String const& s) { + auto ns = s; + ns.append(*this); + *this = move(ns); +} + +void String::prepend(std::string const& s) { + auto ns = String(s); + ns.append(*this); + *this = move(ns); +} + +void String::prepend(Char const* s) { + auto ns = String(s); + ns.append(*this); + *this = move(ns); +} + +void String::prepend(Char const* s, size_t n) { + auto ns = String(s, n); + ns.append(*this); + *this = move(ns); +} + +void String::prepend(char const* s) { + auto ns = String(s); + ns.append(*this); + *this = move(ns); +} + +void String::prepend(char const* s, size_t n) { + auto ns = String(s, n); + ns.append(*this); + *this = move(ns); +} + +void String::prepend(Char c) { + auto ns = String(c, 1); + ns.append(*this); + *this = move(ns); +} + +void String::push_back(Char c) { + append(c); +} + +void String::push_front(Char c) { + prepend(c); +} + +bool String::contains(String const& s, CaseSensitivity cs) const { + return find(s, 0, cs) != NPos; +} + +bool String::regexMatch(String const& regex, bool full, bool caseSensitive) const { + if (full) { + if (caseSensitive) + return std::regex_match(utf8(), std::regex(regex.utf8())); + else + return std::regex_match(utf8(), std::regex(regex.utf8(), std::regex::icase)); + } else { + if (caseSensitive) + return std::regex_search(utf8(), std::regex(regex.utf8())); + else + return std::regex_search(utf8(), std::regex(regex.utf8(), std::regex::icase)); + } +} + +int String::compare(String const& s, CaseSensitivity cs) const { + if (cs == CaseSensitive) + return m_string.compare(s.m_string); + else + return compare(0, NPos, s, 0, NPos, cs); +} + +bool String::equals(String const& s, CaseSensitivity cs) const { + return compare(s, cs) == 0; +} + +bool String::equalsIgnoreCase(String const& s) const { + return compare(s, CaseInsensitive) == 0; +} + +String String::substr(size_t position, size_t n) const { + auto len = size(); + if (position > len) + throw OutOfRangeException(strf("out of range in String::substr(%s, %s)", position, n)); + + if (position == 0 && n >= len) + return *this; + + String ret; + ret.reserve(std::min(n, len - position)); + + auto it = begin(); + std::advance(it, position); + + for (size_t i = 0; i < n; ++i) { + if (it == end()) + break; + ret.append(*it); + ++it; + } + + return ret; +} + +void String::erase(size_t pos, size_t n) { + String ns; + ns.reserve(m_string.size() - std::min(n, m_string.size())); + auto it = begin(); + for (size_t i = 0; i < pos; ++i) + ns.append(*it++); + for (size_t i = 0; i < n; ++i) { + if (it == end()) + break; + ++it; + } + while (it != end()) + ns.append(*it++); + *this = ns; +} + +String String::padLeft(size_t size, String const& filler) const { + if (!filler.length()) + return *this; + String rs; + while (rs.length() + length() < size) { + rs.append(filler); + } + if (rs.length()) + return rs + *this; + return *this; +} + +String String::padRight(size_t size, String const& filler) const { + if (!filler.length()) + return *this; + String rs = *this; + while (rs.length() < size) { + rs.append(filler); + } + return rs; +} + +String& String::operator=(String const& s) { + m_string = s.m_string; + return *this; +} + +String& String::operator=(String&& s) { + m_string = move(s.m_string); + return *this; +} + +String& String::operator+=(String const& s) { + append(s); + return *this; +} + +String& String::operator+=(std::string const& s) { + append(s); + return *this; +} + +String& String::operator+=(Char const* s) { + append(s); + return *this; +} + +String& String::operator+=(char const* s) { + append(s); + return *this; +} + +String& String::operator+=(Char c) { + append(c); + return *this; +} + +bool operator==(String const& s1, String const& s2) { + return s1.m_string == s2.m_string; +} + +bool operator==(String const& s1, std::string const& s2) { + return s1.m_string == s2; +} + +bool operator==(String const& s1, String::Char const* s2) { + return s1 == String(s2); +} + +bool operator==(String const& s1, char const* s2) { + return s1.m_string == s2; +} + +bool operator==(std::string const& s1, String const& s2) { + return s1 == s2.m_string; +} + +bool operator==(String::Char const* s1, String const& s2) { + return String(s1) == s2; +} + +bool operator==(char const* s1, String const& s2) { + return s1 == s2.m_string; +} + +bool operator!=(String const& s1, String const& s2) { + return s1.m_string != s2.m_string; +} + +bool operator!=(String const& s1, std::string const& s2) { + return s1.m_string != s2; +} + +bool operator!=(String const& s1, String::Char const* s2) { + return s1 != String(s2); +} + +bool operator!=(String const& s1, char const* s2) { + return s1.m_string != s2; +} + +bool operator!=(std::string const& s1, String const& s2) { + return s1 != s2.m_string; +} + +bool operator!=(String::Char const* s1, String const& s2) { + return String(s1) != s2; +} + +bool operator!=(char const* s1, String const& s2) { + return s1 != s2.m_string; +} + +bool operator<(String const& s1, String const& s2) { + return s1.m_string < s2.m_string; +} + +bool operator<(String const& s1, std::string const& s2) { + return s1.m_string < s2; +} + +bool operator<(String const& s1, String::Char const* s2) { + return s1 < String(s2); +} + +bool operator<(String const& s1, char const* s2) { + return s1.m_string < s2; +} + +bool operator<(std::string const& s1, String const& s2) { + return s1 < s2.m_string; +} + +bool operator<(String::Char const* s1, String const& s2) { + return String(s1) < s2; +} + +bool operator<(char const* s1, String const& s2) { + return s1 < s2.m_string; +} + +String operator+(String s1, String const& s2) { + s1.append(s2); + return s1; +} + +String operator+(String s1, std::string const& s2) { + s1.append(s2); + return s1; +} + +String operator+(String s1, String::Char const* s2) { + s1.append(s2); + return s1; +} + +String operator+(String s1, char const* s2) { + s1.append(s2); + return s1; +} + +String operator+(std::string const& s1, String const& s2) { + return s1 + s2.m_string; +} + +String operator+(String::Char const* s1, String const& s2) { + return String(s1) + s2; +} + +String operator+(char const* s1, String const& s2) { + return s1 + s2.m_string; +} + +String operator+(String s, String::Char c) { + s.append(c); + return s; +} + +String operator+(String::Char c, String const& s) { + String res(c); + res.append(s); + return res; +} + +String operator*(String const& s, unsigned times) { + String res; + for (unsigned i = 0; i < times; ++i) + res.append(s); + return res; +} + +String operator*(unsigned times, String const& s) { + return s * times; +} + +std::ostream& operator<<(std::ostream& os, String const& s) { + os << s.utf8(); + return os; +} + +std::istream& operator>>(std::istream& is, String& s) { + std::string temp; + is >> temp; + s = String(std::move(temp)); + return is; +} + +int String::compare(size_t selfOffset, size_t selfLen, String const& other, + size_t otherOffset, size_t otherLen, CaseSensitivity cs) const { + auto selfIt = begin(); + auto otherIt = other.begin(); + + while (selfOffset > 0 && selfIt != end()) { + ++selfIt; + --selfOffset; + } + + while (otherOffset > 0 && otherIt != other.end()) { + ++otherIt; + --otherLen; + } + + while (true) { + if ((selfIt == end() || selfLen == 0) && (otherIt == other.end() || otherLen == 0)) + return 0; + else if (selfIt == end() || selfLen == 0) + return -1; + else if (otherIt == other.end() || otherLen == 0) + return 1; + + auto c1 = *selfIt; + auto c2 = *otherIt; + + if (cs == CaseInsensitive) { + c1 = toLower(c1); + c2 = toLower(c2); + } + + if (c1 < c2) + return -1; + else if (c2 < c1) + return 1; + + ++selfIt; + ++otherIt; + --selfLen; + --otherLen; + } +} + +StringList::StringList() : Base() {} + +StringList::StringList(Base const& l) : Base(l) {} + +StringList::StringList(Base&& l) : Base(std::move(l)) {} + +StringList::StringList(StringList const& l) : Base(l) {} + +StringList::StringList(StringList&& l) : Base(std::move(l)) {} + +StringList::StringList(size_t n, String::Char const* const* list) { + for (size_t i = 0; i < n; ++i) + append(String(list[i])); +} + +StringList::StringList(size_t n, char const* const* list) { + for (size_t i = 0; i < n; ++i) + append(String(list[i])); +} + +StringList::StringList(size_t len, String const& s1) : Base(len, s1) {} + +StringList::StringList(std::initializer_list<String> list) : Base(list) {} + +StringList& StringList::operator=(Base const& rhs) { + Base::operator=(rhs); + return *this; +} + +StringList& StringList::operator=(Base&& rhs) { + Base::operator=(std::move(rhs)); + return *this; +} + +StringList& StringList::operator=(StringList const& rhs) { + Base::operator=(rhs); + return *this; +} + +StringList& StringList::operator=(StringList&& rhs) { + Base::operator=(std::move(rhs)); + return *this; +} + +StringList& StringList::operator=(initializer_list<String> list) { + Base::operator=(std::move(list)); + return *this; +} + +bool StringList::contains(String const& s, String::CaseSensitivity cs) const { + for (const_iterator i = begin(); i != end(); ++i) { + if (s.compare(*i, cs) == 0) + return true; + } + return false; +} + +StringList StringList::trimAll(String const& pattern) const { + StringList r; + for (auto const& s : *this) + r.append(s.trim(pattern)); + return r; +} + +String StringList::join(String const& separator) const { + String joinedString; + for (const_iterator i = begin(); i != end(); ++i) { + if (i != begin()) + joinedString += separator; + joinedString += *i; + } + + return joinedString; +} + +StringList StringList::slice(SliceIndex a, SliceIndex b, int i) const { + return Star::slice(*this, a, b, i); +} + +StringList StringList::sorted() const { + StringList l = *this; + l.sort(); + return l; +} + +std::ostream& operator<<(std::ostream& os, const StringList& list) { + os << "("; + for (auto i = list.begin(); i != list.end(); ++i) { + if (i != list.begin()) + os << ", "; + + os << '\'' << *i << '\''; + } + os << ")"; + return os; +} + +size_t hash<StringList>::operator()(StringList const& sl) const { + size_t h = 0; + for (auto const& s : sl) + hashCombine(h, hashOf(s)); + return h; +} + +} diff --git a/source/core/StarString.hpp b/source/core/StarString.hpp new file mode 100644 index 0000000..0f30fd2 --- /dev/null +++ b/source/core/StarString.hpp @@ -0,0 +1,462 @@ +#ifndef STAR_STRING_HPP +#define STAR_STRING_HPP + +#include "StarUnicode.hpp" +#include "StarHash.hpp" +#include "StarByteArray.hpp" +#include "StarList.hpp" +#include "StarMap.hpp" +#include "StarSet.hpp" + +namespace Star { + +STAR_CLASS(StringList); +STAR_CLASS(String); + +STAR_EXCEPTION(StringException, StarException); + +// A Unicode string class, which is a basic UTF-8 aware wrapper around +// std::string. Provides methods for accessing UTF-32 "Char" type, which +// provides access to each individual code point. Printing, hashing, copying, +// and in-order access should be basically as fast as std::string, but the more +// complex string processing methods may be much worse. +// +// All case sensitive / insensitive functionality is based on ASCII tolower and +// toupper, and will have no effect on characters outside ASCII. Therefore, +// case insensitivity is really only appropriate for code / script processing, +// not for general strings. +class String { +public: + typedef Utf32Type Char; + + // std::basic_string equivalent that guarantees const access time for + // operator[], etc + typedef std::basic_string<Char> WideString; + + typedef U8ToU32Iterator<std::string::const_iterator> const_iterator; + typedef Char value_type; + typedef value_type const& const_reference; + + enum CaseSensitivity { + CaseSensitive, + CaseInsensitive + }; + + // Space, horizontal tab, newline, carriage return, and BOM / ZWNBSP + static bool isSpace(Char c); + static bool isAsciiNumber(Char c); + static bool isAsciiLetter(Char c); + + // These methods only actually work on unicode characters below 127, i.e. + // ASCII subset. + static Char toLower(Char c); + static Char toUpper(Char c); + static bool charEqual(Char c1, Char c2, CaseSensitivity cs); + + // Join two strings together with a joiner, so that only one instance of the + // joiner is in between the left and right strings. For example, joins "foo" + // and "bar" with "?" to produce "foo?bar". Gets rid of repeat joiners, so + // "foo?" and "?bar" with "?" also becomes "foo?bar". Also, if left or right + // is empty, does not add a joiner, for example "" and "baz" joined with "?" + // produces "baz". + static String joinWith(String const& join, String const& left, String const& right); + template <typename... StringType> + static String joinWith(String const& join, String const& first, String const& second, String const& third, StringType const&... rest); + + String(); + String(String const& s); + String(String&& s); + + // These assume utf8 input + String(char const* s); + String(char const* s, size_t n); + String(std::string const& s); + String(std::string&& s); + + String(std::wstring const& s); + String(Char const* s); + String(Char const* s, size_t n); + String(Char c, size_t n); + + explicit String(Char c); + + // const& to internal utf8 data + std::string const& utf8() const; + std::string takeUtf8(); + ByteArray utf8Bytes() const; + // Pointer to internal utf8 data, null-terminated. + char const* utf8Ptr() const; + size_t utf8Size() const; + + std::wstring wstring() const; + WideString wideString() const; + + const_iterator begin() const; + const_iterator end() const; + + size_t size() const; + size_t length() const; + + void clear(); + void reserve(size_t n); + bool empty() const; + + Char operator[](size_t i) const; + // Throws StringException if i out of range. + Char at(size_t i) const; + + String toUpper() const; + String toLower() const; + String titleCase() const; + + bool endsWith(String const& end, CaseSensitivity cs = CaseSensitive) const; + bool endsWith(Char end, CaseSensitivity cs = CaseSensitive) const; + bool beginsWith(String const& beg, CaseSensitivity cs = CaseSensitive) const; + bool beginsWith(Char beg, CaseSensitivity cs = CaseSensitive) const; + + String reverse() const; + + String rot13() const; + + StringList split(Char c, size_t maxSplit = NPos) const; + StringList split(String const& pattern, size_t maxSplit = NPos) const; + StringList rsplit(Char c, size_t maxSplit = NPos) const; + StringList rsplit(String const& pattern, size_t maxSplit = NPos) const; + + // Splits on any number of contiguous instances of any of the given + // characters. Behaves differently than regular split in that leading and + // trailing instances of the characters are also ignored, and in general no + // empty strings will be in the resulting split list. If chars is empty, + // then splits on any whitespace. + StringList splitAny(String const& chars = "", size_t maxSplit = NPos) const; + StringList rsplitAny(String const& chars = "", size_t maxSplit = NPos) const; + + // Split any with '\n\r' + StringList splitLines(size_t maxSplit = NPos) const; + // Shorthand for splitAny(""); + StringList splitWhitespace(size_t maxSplit = NPos) const; + + // Splits a string once based on the given characters (defaulting to + // whitespace), and returns the first part. This string is set to the + // second part. + String extract(String const& chars = ""); + String rextract(String const& chars = ""); + + bool hasChar(Char c) const; + // Identical to hasChar, except, if string is empty, tests if c is + // whitespace. + bool hasCharOrWhitespace(Char c) const; + + String replace(String const& rplc, String const& val) const; + + String trimEnd(String const& chars = "") const; + String trimBeg(String const& chars = "") const; + String trim(String const& chars = "") const; + + size_t find(Char c, size_t beg = 0, CaseSensitivity cs = CaseSensitive) const; + size_t find(String const& s, size_t beg = 0, CaseSensitivity cs = CaseSensitive) const; + size_t findLast(Char c, CaseSensitivity cs = CaseSensitive) const; + size_t findLast(String const& s, CaseSensitivity cs = CaseSensitive) const; + + // If pattern is empty, finds first whitespace + size_t findFirstOf(String const& chars = "", size_t beg = 0) const; + + // If pattern is empty, finds first non-whitespace + size_t findFirstNotOf(String const& chars = "", size_t beg = 0) const; + + // finds the the start of the next 'boundary' in a string. used for quickly + // scanning a string + size_t findNextBoundary(size_t index, bool backwards = false) const; + + String slice(SliceIndex a = SliceIndex(), SliceIndex b = SliceIndex(), int i = 1) const; + + void append(String const& s); + void append(std::string const& s); + void append(Char const* s); + void append(Char const* s, size_t n); + void append(char const* s); + void append(char const* s, size_t n); + void append(Char c); + + void prepend(String const& s); + void prepend(std::string const& s); + void prepend(Char const* s); + void prepend(Char const* s, size_t n); + void prepend(char const* s); + void prepend(char const* s, size_t n); + void prepend(Char c); + + void push_back(Char c); + void push_front(Char c); + + bool contains(String const& s, CaseSensitivity cs = CaseSensitive) const; + + // Does this string match the given regular expression? + bool regexMatch(String const& regex, bool full = true, bool caseSensitive = true) const; + + int compare(String const& s, CaseSensitivity cs = CaseSensitive) const; + bool equals(String const& s, CaseSensitivity cs = CaseSensitive) const; + // Synonym for equals(s, String::CaseInsensitive) + bool equalsIgnoreCase(String const& s) const; + + String substr(size_t position, size_t n = NPos) const; + void erase(size_t pos = 0, size_t n = NPos); + + String padLeft(size_t size, String const& filler) const; + String padRight(size_t size, String const& filler) const; + + // Replace angle bracket tags in the string with values given by the given + // lookup function. Will be called as: + // String lookup(String const& key); + template <typename Lookup> + String lookupTags(Lookup&& lookup) const; + + // Replace angle bracket tags in the string with values given by the tags + // map. If replaceWithDefault is true, then values that are not found in the + // tags map are replace with the default string. If replaceWithDefault is + // false, tags that are not found are not replaced at all. + template <typename MapType> + String replaceTags(MapType const& tags, bool replaceWithDefault = false, String defaultValue = "") const; + + String& operator=(String const& s); + String& operator=(String&& s); + + String& operator+=(String const& s); + String& operator+=(std::string const& s); + String& operator+=(Char const* s); + String& operator+=(char const* s); + String& operator+=(Char c); + + friend bool operator==(String const& s1, String const& s2); + friend bool operator==(String const& s1, std::string const& s2); + friend bool operator==(String const& s1, Char const* s2); + friend bool operator==(String const& s1, char const* s2); + friend bool operator==(std::string const& s1, String const& s2); + friend bool operator==(Char const* s1, String const& s2); + friend bool operator==(char const* s1, String const& s2); + + friend bool operator!=(String const& s1, String const& s2); + friend bool operator!=(String const& s1, std::string const& s2); + friend bool operator!=(String const& s1, Char const* s2); + friend bool operator!=(String const& s1, char const* c); + friend bool operator!=(std::string const& s1, String const& s2); + friend bool operator!=(Char const* s1, String const& s2); + friend bool operator!=(char const* s1, String const& s2); + + friend bool operator<(String const& s1, String const& s2); + friend bool operator<(String const& s1, std::string const& s2); + friend bool operator<(String const& s1, Char const* s2); + friend bool operator<(String const& s1, char const* s2); + friend bool operator<(std::string const& s1, String const& s2); + friend bool operator<(Char const* s1, String const& s2); + friend bool operator<(char const* s1, String const& s2); + + friend String operator+(String s1, String const& s2); + friend String operator+(String s1, std::string const& s2); + friend String operator+(String s1, Char const* s2); + friend String operator+(String s1, char const* s2); + friend String operator+(std::string const& s1, String const& s2); + friend String operator+(Char const* s1, String const& s2); + friend String operator+(char const* s1, String const& s2); + + friend String operator+(String s, Char c); + friend String operator+(Char c, String const& s); + + friend String operator*(String const& s, unsigned times); + friend String operator*(unsigned times, String const& s); + + friend std::ostream& operator<<(std::ostream& os, String const& s); + friend std::istream& operator>>(std::istream& is, String& s); + +private: + int compare(size_t selfOffset, + size_t selfLen, + String const& other, + size_t otherOffset, + size_t otherLen, + CaseSensitivity cs) const; + + std::string m_string; +}; + +class StringList : public List<String> { +public: + typedef List<String> Base; + + typedef Base::iterator iterator; + typedef Base::const_iterator const_iterator; + typedef Base::value_type value_type; + typedef Base::reference reference; + typedef Base::const_reference const_reference; + + template <typename Container> + static StringList from(Container const& m); + + StringList(); + StringList(Base const& l); + StringList(Base&& l); + StringList(StringList const& l); + StringList(StringList&& l); + StringList(size_t len, String::Char const* const* list); + StringList(size_t len, char const* const* list); + explicit StringList(size_t len, String const& s1 = String()); + StringList(std::initializer_list<String> list); + + template <typename InputIterator> + StringList(InputIterator beg, InputIterator end) + : Base(beg, end) {} + + StringList& operator=(Base const& rhs); + StringList& operator=(Base&& rhs); + StringList& operator=(StringList const& rhs); + StringList& operator=(StringList&& rhs); + StringList& operator=(initializer_list<String> list); + + bool contains(String const& s, String::CaseSensitivity cs = String::CaseSensitive) const; + StringList trimAll(String const& chars = "") const; + String join(String const& separator = "") const; + + StringList slice(SliceIndex a = SliceIndex(), SliceIndex b = SliceIndex(), int i = 1) const; + + template <typename Filter> + StringList filtered(Filter&& filter) const; + + template <typename Comparator> + StringList sorted(Comparator&& comparator) const; + + StringList sorted() const; +}; + +std::ostream& operator<<(std::ostream& os, StringList const& list); + +template <> +struct hash<String> { + size_t operator()(String const& s) const; +}; + +struct CaseInsensitiveStringHash { + size_t operator()(String const& s) const; +}; + +struct CaseInsensitiveStringCompare { + bool operator()(String const& lhs, String const& rhs) const; +}; + +typedef HashSet<String> StringSet; + +template <typename MappedT, typename HashT = hash<String>, typename ComparatorT = std::equal_to<String>> +using StringMap = HashMap<String, MappedT, HashT, ComparatorT>; + +template <typename MappedT, typename HashT = hash<String>, typename ComparatorT = std::equal_to<String>> +using StableStringMap = StableHashMap<String, MappedT, HashT, ComparatorT>; + +template <typename MappedT> +using CaseInsensitiveStringMap = StringMap<MappedT, CaseInsensitiveStringHash, CaseInsensitiveStringCompare>; + +template <> +struct hash<StringList> { + size_t operator()(StringList const& s) const; +}; + +template <typename... StringType> +String String::joinWith( + String const& join, String const& first, String const& second, String const& third, StringType const&... rest) { + return joinWith(join, joinWith(join, first, second), third, rest...); +} + +template <typename Lookup> +String String::lookupTags(Lookup&& lookup) const { + // Operates directly on the utf8 representation of the strings, rather than + // using unicode find / replace methods + + auto substrInto = [](std::string const& ref, size_t position, size_t n, std::string& result) { + auto len = ref.size(); + if (position > len) + throw OutOfRangeException(strf("out of range in substrInto: %s", position)); + + auto it = ref.begin(); + std::advance(it, position); + + for (size_t i = 0; i < n; ++i) { + if (it == ref.end()) + break; + result.push_back(*it); + ++it; + } + }; + + std::string finalString; + + size_t start = 0; + size_t size = String::size(); + + finalString.reserve(size); + + String key; + + while (true) { + if (start >= size) + break; + + size_t beginTag = m_string.find("<", start); + size_t endTag = m_string.find(">", beginTag); + if (beginTag != NPos && endTag != NPos) { + substrInto(m_string, beginTag + 1, endTag - beginTag - 1, key.m_string); + substrInto(m_string, start, beginTag - start, finalString); + finalString += lookup(key).m_string; + key.m_string.clear(); + start = endTag + 1; + + } else { + substrInto(m_string, start, NPos, finalString); + break; + } + } + + return move(finalString); +} + +template <typename MapType> +String String::replaceTags(MapType const& tags, bool replaceWithDefault, String defaultValue) const { + return lookupTags([&](String const& key) -> String { + auto i = tags.find(key); + if (i == tags.end()) { + if (replaceWithDefault) + return defaultValue; + else + return "<" + key + ">"; + } else { + return i->second; + } + }); +} + +inline size_t hash<String>::operator()(String const& s) const { + PLHasher hash; + for (auto c : s.utf8()) + hash.put(c); + return hash.hash(); +} + +template <typename Container> +StringList StringList::from(Container const& m) { + return StringList(m.begin(), m.end()); +} + +template <typename Filter> +StringList StringList::filtered(Filter&& filter) const { + StringList l; + l.filter(forward<Filter>(filter)); + return l; +} + +template <typename Comparator> +StringList StringList::sorted(Comparator&& comparator) const { + StringList l; + l.sort(forward<Comparator>(comparator)); + return l; +} + +} + +#endif diff --git a/source/core/StarString_windows.cpp b/source/core/StarString_windows.cpp new file mode 100644 index 0000000..77c3e89 --- /dev/null +++ b/source/core/StarString_windows.cpp @@ -0,0 +1,34 @@ +#include "StarString_windows.hpp" + +namespace Star { + +size_t wcharLen(WCHAR const* s) { + size_t size = 0; + while (*s) { + ++size; + ++s; + } + return size; +} + +String utf16ToString(WCHAR const* s) { + if (!s) + return ""; + int sLen = wcharLen(s); + int utf8Len = WideCharToMultiByte(CP_UTF8, 0, s, sLen + 1, NULL, 0, NULL, NULL); + auto utf8Buffer = new char[utf8Len]; + WideCharToMultiByte(CP_UTF8, 0, s, sLen + 1, utf8Buffer, utf8Len, NULL, NULL); + auto result = String(utf8Buffer, utf8Len - 1); + delete[] utf8Buffer; + return result; +} + +unique_ptr<WCHAR[]> stringToUtf16(String const& s) { + int utf16Len = MultiByteToWideChar(CP_UTF8, 0, s.utf8Ptr(), s.utf8Size() + 1, NULL, 0); + unique_ptr<WCHAR[]> result; + result.reset(new WCHAR[utf16Len]); + MultiByteToWideChar(CP_UTF8, 0, s.utf8Ptr(), s.utf8Size() + 1, result.get(), utf16Len); + return result; +} + +} diff --git a/source/core/StarString_windows.hpp b/source/core/StarString_windows.hpp new file mode 100644 index 0000000..e681a9b --- /dev/null +++ b/source/core/StarString_windows.hpp @@ -0,0 +1,15 @@ +#ifndef STAR_STRING_WINDOWS_HPP +#define STAR_STRING_WINDOWS_HPP + +#include <windows.h> + +#include "StarString.hpp" + +namespace Star { + +String utf16ToString(WCHAR const* s); +unique_ptr<WCHAR[]> stringToUtf16(String const& s); + +} + +#endif diff --git a/source/core/StarStrongTypedef.hpp b/source/core/StarStrongTypedef.hpp new file mode 100644 index 0000000..f5cb5ac --- /dev/null +++ b/source/core/StarStrongTypedef.hpp @@ -0,0 +1,105 @@ +#ifndef STAR_STRONG_TYPEDEF_HPP +#define STAR_STRONG_TYPEDEF_HPP + +#include <type_traits> + +// Defines a new type that behaves nearly identical to 'parentType', with the +// added benefit that though the new type can be implicitly converted to the +// base type, it must be explicitly converted *from* the base type, and they +// are two distinct types in the type system. +#define strong_typedef(ParentType, NewType) \ + template <typename BaseType> \ + struct NewType##Wrapper : BaseType { \ + using BaseType::BaseType; \ + \ + NewType##Wrapper() : BaseType() {} \ + \ + NewType##Wrapper(NewType##Wrapper const& nt) : BaseType(nt) {} \ + \ + NewType##Wrapper(NewType##Wrapper&& nt) : BaseType(std::move(nt)) {} \ + \ + explicit NewType##Wrapper(BaseType const& bt) : BaseType(bt) {} \ + \ + explicit NewType##Wrapper(BaseType&& bt) : BaseType(std::move(bt)) {} \ + \ + NewType##Wrapper& operator=(NewType##Wrapper const& rhs) { \ + BaseType::operator=(rhs); \ + return *this; \ + } \ + \ + NewType##Wrapper& operator=(NewType##Wrapper&& rhs) { \ + BaseType::operator=(std::move(rhs)); \ + return *this; \ + } \ + \ + template <class Arg> \ + NewType##Wrapper& operator=(Arg&& other) { \ + static_assert(std::is_base_of<BaseType, typename std::decay<Arg>::type>::value == false \ + || std::is_same<NewType##Wrapper, typename std::decay<Arg>::type>::value, \ + "" #NewType " can not implicitly be assigned from " #ParentType "-derived classes or strong " #ParentType \ + " typedefs"); \ + \ + BaseType::operator=(std::forward<Arg>(other)); \ + return *this; \ + } \ + }; \ + typedef NewType##Wrapper<ParentType> NewType + +// Version of strong_typedef for builtin types. +#define strong_typedef_builtin(Type, NewType) \ + struct NewType { \ + Type t; \ + \ + explicit NewType(const Type t_) \ + : t(t_){}; \ + \ + NewType() \ + : t(Type()) {} \ + \ + NewType(const NewType& t_) \ + : t(t_.t) {} \ + \ + NewType& operator=(const NewType& rhs) { \ + t = rhs.t; \ + return *this; \ + } \ + \ + NewType& operator=(Type const& rhs) { \ + t = rhs; \ + return *this; \ + } \ + \ + operator const Type&() const { \ + return t; \ + } \ + \ + operator Type&() { \ + return t; \ + } \ + \ + bool operator==(NewType const& rhs) const { \ + return t == rhs.t; \ + } \ + \ + bool operator!=(NewType const& rhs) const { \ + return t != rhs.t; \ + } \ + \ + bool operator<(NewType const& rhs) const { \ + return t < rhs.t; \ + } \ + \ + bool operator>(NewType const& rhs) const { \ + return t > rhs.t; \ + } \ + \ + bool operator<=(NewType const& rhs) const { \ + return t <= rhs.t; \ + } \ + \ + bool operator>=(NewType const& rhs) const { \ + return t >= rhs.t; \ + } \ + } + +#endif diff --git a/source/core/StarTcp.cpp b/source/core/StarTcp.cpp new file mode 100644 index 0000000..dc64b98 --- /dev/null +++ b/source/core/StarTcp.cpp @@ -0,0 +1,221 @@ +#include "StarTcp.hpp" +#include "StarLogging.hpp" +#include "StarNetImpl.hpp" + +namespace Star { + +TcpSocketPtr TcpSocket::connectTo(HostAddressWithPort const& addressWithPort) { + auto socket = TcpSocketPtr(new TcpSocket(addressWithPort.address().mode())); + socket->connect(addressWithPort); + return socket; +} + +TcpSocketPtr TcpSocket::listen(HostAddressWithPort const& addressWithPort) { + auto socket = TcpSocketPtr(new TcpSocket(addressWithPort.address().mode())); + socket->bind(addressWithPort); + ((Socket&)(*socket)).listen(32); + return socket; +} + +TcpSocketPtr TcpSocket::accept() { + ReadLocker locker(m_mutex); + + if (m_socketMode != SocketMode::Bound) + throw SocketClosedException("TcpSocket not bound in TcpSocket::accept"); + + struct sockaddr_storage sockAddr; + socklen_t sockAddrLen = sizeof(sockAddr); + + auto socketDesc = ::accept(m_impl->socketDesc, (struct sockaddr*)&sockAddr, &sockAddrLen); + + if (invalidSocketDescriptor(socketDesc)) { + if (netErrorInterrupt()) + return {}; + throw NetworkException(strf("Cannot accept connection: %s", netErrorString())); + } + + auto socketImpl = make_shared<SocketImpl>(); + socketImpl->socketDesc = socketDesc; + +#if defined STAR_SYSTEM_MACOS || defined STAR_SYSTEM_FREEBSD + // Don't generate sigpipe + int set = 1; + socketImpl->setSockOpt(SOL_SOCKET, SO_NOSIGPIPE, (void*)&set, sizeof(int)); +#endif + + TcpSocketPtr sockPtr(new TcpSocket(m_localAddress.address().mode(), socketImpl)); + + sockPtr->m_localAddress = m_localAddress; + setAddressFromNative(sockPtr->m_remoteAddress, m_localAddress.address().mode(), &sockAddr); + Logger::debug("accept from %s (%d)", sockPtr->m_remoteAddress, sockPtr->m_impl->socketDesc); + + return sockPtr; +} + +void TcpSocket::setNoDelay(bool noDelay) { + ReadLocker locker(m_mutex); + checkOpen("TcpSocket::setNoDelay"); + + int flag = noDelay ? 1 : 0; + m_impl->setSockOpt(IPPROTO_TCP, TCP_NODELAY, (char*)&flag, sizeof(flag)); +} + +size_t TcpSocket::receive(char* data, size_t size) { + ReadLocker locker(m_mutex); + checkOpen("TcpSocket::receive"); + + if (m_socketMode == SocketMode::Closed) + throw SocketClosedException("TcpSocket not open in TcpSocket::receive"); + + int flags = 0; +#ifdef STAR_SYSTEM_LINUX + // Don't generate sigpipe + flags |= MSG_NOSIGNAL; +#endif + + auto r = ::recv(m_impl->socketDesc, data, size, flags); + if (r < 0) { + if (m_socketMode == SocketMode::Shutdown) { + throw SocketClosedException("Connection closed"); + } else if (netErrorConnectionReset()) { + doShutdown(); + throw SocketClosedException("Connection reset"); + } else if (netErrorInterrupt()) { + r = 0; + } else { + throw NetworkException(strf("tcp recv error: %s", netErrorString())); + } + } + + return r; +} + +size_t TcpSocket::send(char const* data, size_t size) { + ReadLocker locker(m_mutex); + checkOpen("TcpSocket::send"); + + if (m_socketMode == SocketMode::Closed) + throw SocketClosedException("TcpSocket not open in TcpSocket::send"); + + int flags = 0; +#ifdef STAR_SYSTEM_LINUX + // Don't generate sigpipe + flags |= MSG_NOSIGNAL; +#endif + + auto w = ::send(m_impl->socketDesc, data, size, flags); + if (w < 0) { + if (m_socketMode == SocketMode::Shutdown) { + throw SocketClosedException("Connection closed"); + } else if (netErrorConnectionReset()) { + doShutdown(); + throw SocketClosedException("Connection reset"); + } else if (netErrorInterrupt()) { + w = 0; + } else { + throw NetworkException(strf("tcp send error: %s", netErrorString())); + } + } + + return w; +} + +HostAddressWithPort TcpSocket::localAddress() const { + ReadLocker locker(m_mutex); + return m_localAddress; +} + +HostAddressWithPort TcpSocket::remoteAddress() const { + ReadLocker locker(m_mutex); + return m_remoteAddress; +} + +TcpSocket::TcpSocket(NetworkMode networkMode) : Socket(SocketType::Tcp, networkMode) {} + +TcpSocket::TcpSocket(NetworkMode networkMode, SocketImplPtr impl) : Socket(networkMode, impl, SocketMode::Connected) {} + +void TcpSocket::connect(HostAddressWithPort const& addressWithPort) { + WriteLocker locker(m_mutex); + checkOpen("TcpSocket::connect"); + + if (m_networkMode != addressWithPort.address().mode()) + throw NetworkException("Socket address type mismatch between address and socket."); + + struct sockaddr_storage sockAddr; + socklen_t sockAddrLen; + setNativeFromAddress(addressWithPort, &sockAddr, &sockAddrLen); + if (::connect(m_impl->socketDesc, (struct sockaddr*)&sockAddr, sockAddrLen) < 0) + throw NetworkException(strf("cannot connect to %s: %s", addressWithPort, netErrorString())); + +#if defined STAR_SYSTEM_MACOS || defined STAR_SYSTEM_FREEBSD + // Don't generate sigpipe + int set = 1; + m_impl->setSockOpt(SOL_SOCKET, SO_NOSIGPIPE, (void*)&set, sizeof(set)); +#endif + + m_socketMode = SocketMode::Connected; + m_remoteAddress = addressWithPort; +} + +TcpServer::TcpServer(HostAddressWithPort const& address) : m_hostAddress(address) { + m_hostAddress = address; + m_listenSocket = TcpSocket::listen(address); + m_listenSocket->setNonBlocking(true); + Logger::debug("TcpServer listening on: %s", address); +} + +TcpServer::TcpServer(uint16_t port) : TcpServer(HostAddressWithPort("*", port)) {} + +TcpServer::~TcpServer() { + stop(); +} + +void TcpServer::stop() { + m_listenSocket->shutdown(); + m_callbackThread.finish(); + m_listenSocket->close(); +} + +bool TcpServer::isListening() const { + return m_listenSocket->isActive(); +} + +TcpSocketPtr TcpServer::accept(unsigned timeout) { + MutexLocker locker(m_mutex); + Socket::poll({{m_listenSocket, {true, false}}}, timeout); + try { + return m_listenSocket->accept(); + } catch (SocketClosedException const&) { + return {}; + } +} + +void TcpServer::setAcceptCallback(AcceptCallback callback, unsigned timeout) { + MutexLocker locker(m_mutex); + m_callback = callback; + if (m_listenSocket->isActive() && !m_callbackThread) { + m_callbackThread = Thread::invoke("TcpServer::acceptCallback", [this, timeout]() { + try { + while (true) { + TcpSocketPtr conn; + try { + conn = accept(timeout); + } catch (NetworkException const& e) { + Logger::error("TcpServer caught exception accepting connection %s", outputException(e, false)); + } + + if (conn) + m_callback(conn); + + if (!m_listenSocket->isActive()) + break; + } + } catch (std::exception const& e) { + Logger::error("TcpServer will close, listener thread caught exception: %s", outputException(e, true)); + m_listenSocket->close(); + } + }); + } +} + +} diff --git a/source/core/StarTcp.hpp b/source/core/StarTcp.hpp new file mode 100644 index 0000000..362788e --- /dev/null +++ b/source/core/StarTcp.hpp @@ -0,0 +1,75 @@ +#ifndef STAR_TCP_HPP +#define STAR_TCP_HPP + +#include "StarIODevice.hpp" +#include "StarSocket.hpp" +#include "StarThread.hpp" + +namespace Star { + +STAR_CLASS(TcpSocket); +STAR_CLASS(TcpServer); + +class TcpSocket : public Socket { +public: + static TcpSocketPtr connectTo(HostAddressWithPort const& address); + static TcpSocketPtr listen(HostAddressWithPort const& address); + + TcpSocketPtr accept(); + + // Must be called after connect. Sets TCP_NODELAY option. + void setNoDelay(bool noDelay); + + size_t receive(char* data, size_t len); + size_t send(char const* data, size_t len); + + HostAddressWithPort localAddress() const; + HostAddressWithPort remoteAddress() const; + +private: + TcpSocket(NetworkMode networkMode); + TcpSocket(NetworkMode networkMode, SocketImplPtr impl); + + void connect(HostAddressWithPort const& address); + + HostAddressWithPort m_remoteAddress; +}; + +// Simple class to listen for and open TcpSocket instances. +class TcpServer { +public: + typedef function<void(TcpSocketPtr socket)> AcceptCallback; + + TcpServer(HostAddressWithPort const& address); + // Listens to all interfaces. + TcpServer(uint16_t port); + ~TcpServer(); + + void stop(); + bool isListening() const; + + // Blocks until next connection available for the given timeout. Throws + // ServerClosed if close() is called. Cannot be called if AcceptCallback is + // set. + TcpSocketPtr accept(unsigned timeout); + + // Rather than calling and blocking on accept(), if an AcceptCallback is set + // here, it will be called whenever a new connection is available. + // Exceptions thrown from the callback function will be caught and logged, + // and will cause the server to close. The timeout here is the timeout that + // is passed to accept in the loop, the longer the timeout the slower it will + // shutdown on a call to close. + void setAcceptCallback(AcceptCallback callback, unsigned timeout = 20); + +private: + mutable Mutex m_mutex; + + AcceptCallback m_callback; + ThreadFunction<void> m_callbackThread; + HostAddressWithPort m_hostAddress; + TcpSocketPtr m_listenSocket; +}; + +} + +#endif diff --git a/source/core/StarThread.cpp b/source/core/StarThread.cpp new file mode 100644 index 0000000..e76d25c --- /dev/null +++ b/source/core/StarThread.cpp @@ -0,0 +1,121 @@ +#include "StarThread.hpp" +#include "StarFormat.hpp" + +namespace Star { + +ReadersWriterMutex::ReadersWriterMutex() + : m_readers(), m_writers(), m_readWaiters(), m_writeWaiters() {} + +void ReadersWriterMutex::readLock() { + MutexLocker locker(m_mutex); + if (m_writers || m_writeWaiters) { + m_readWaiters++; + while (m_writers || m_writeWaiters) + m_readCond.wait(m_mutex); + m_readWaiters--; + } + m_readers++; +} + +bool ReadersWriterMutex::tryReadLock() { + MutexLocker locker(m_mutex); + if (m_writers || m_writeWaiters) + return false; + m_readers++; + return true; +} + +void ReadersWriterMutex::readUnlock() { + MutexLocker locker(m_mutex); + m_readers--; + if (m_writeWaiters) + m_writeCond.signal(); +} + +void ReadersWriterMutex::writeLock() { + MutexLocker locker(m_mutex); + if (m_readers || m_writers) { + m_writeWaiters++; + while (m_readers || m_writers) + m_writeCond.wait(m_mutex); + m_writeWaiters--; + } + m_writers = 1; +} + +bool ReadersWriterMutex::tryWriteLock() { + MutexLocker locker(m_mutex); + if (m_readers || m_writers) + return false; + m_writers = 1; + return true; +} + +void ReadersWriterMutex::writeUnlock() { + MutexLocker locker(m_mutex); + m_writers = 0; + if (m_writeWaiters) + m_writeCond.signal(); + else if (m_readWaiters) + m_readCond.broadcast(); +} + +ReadLocker::ReadLocker(ReadersWriterMutex& rwlock, bool startLocked) : m_lock(rwlock), m_locked(false) { + if (startLocked) + lock(); +} + +ReadLocker::~ReadLocker() { + unlock(); +} + +void ReadLocker::unlock() { + if (m_locked) + m_lock.readUnlock(); + m_locked = false; +} + +void ReadLocker::lock() { + if (!m_locked) + m_lock.readLock(); + m_locked = true; +} + +bool ReadLocker::tryLock() { + if (!m_locked) { + m_locked = m_lock.tryReadLock(); + return m_locked; + } + return true; +} + +WriteLocker::WriteLocker(ReadersWriterMutex& rwlock, bool startLocked) : m_lock(rwlock), m_locked(false) { + if (startLocked) + lock(); +} + +WriteLocker::~WriteLocker() { + unlock(); +} + +void WriteLocker::unlock() { + if (m_locked) + m_lock.writeUnlock(); + m_locked = false; +} + +void WriteLocker::lock() { + if (!m_locked) + m_lock.writeLock(); + m_locked = true; +} + +bool WriteLocker::tryLock() { + if (!m_locked) { + m_locked = m_lock.tryWriteLock(); + return m_locked; + } + return true; +} + +} diff --git a/source/core/StarThread.hpp b/source/core/StarThread.hpp new file mode 100644 index 0000000..5d2c69f --- /dev/null +++ b/source/core/StarThread.hpp @@ -0,0 +1,425 @@ +#ifndef STAR_THREAD_HPP +#define STAR_THREAD_HPP + +#include "StarException.hpp" +#include "StarString.hpp" + +namespace Star { + +STAR_STRUCT(ThreadImpl); +STAR_STRUCT(ThreadFunctionImpl); +STAR_STRUCT(MutexImpl); +STAR_STRUCT(ConditionVariableImpl); +STAR_STRUCT(RecursiveMutexImpl); + +template <typename Return> +class ThreadFunction; + +class Thread { +public: + // Implementations of this method should sleep for at least the given amount + // of time, but may sleep for longer due to scheduling. + static void sleep(unsigned millis); + + // Sleep a more precise amount of time, but uses more resources to do so. + // Should be less likely to sleep much longer than the given amount of time. + static void sleepPrecise(unsigned millis); + + // Yield this thread, offering the opportunity to reschedule. + static void yield(); + + static unsigned numberOfProcessors(); + + template <typename Function, typename... Args> + static ThreadFunction<decltype(std::declval<Function>()(std::declval<Args>()...))> invoke(String const& name, Function&& f, Args&&... args); + + Thread(String const& name); + Thread(Thread&&); + // Will not automatically join! ALL implementations of this class MUST call + // join() in their most derived constructors, or not rely on the destructor + // joining. + virtual ~Thread(); + + Thread& operator=(Thread&&); + + // Start a thread that is currently in the joined state. Returns true if the + // thread was joined and is now started, false if the thread was not joined. + bool start(); + + // Wait for a thread to finish and re-join with the thread, on completion + // isJoined() will be false. Returns true if the thread was joinable, and is + // now joined, false if the thread was already joined. + bool join(); + + // Returns false when this thread been started without being joined. This is + // subtlely different than "!isRunning()", in that the thread could have + // completed its work, but a thread *must* be joined before being restarted. + bool isJoined() const; + + // Returns false before start() has been called, true immediately after + // start() has been called, and false once the run() method returns. + bool isRunning() const; + + String name(); + +protected: + virtual void run() = 0; + +private: + unique_ptr<ThreadImpl> m_impl; +}; + +// Wraps a function call and calls in another thread, very nice lightweight +// one-shot alternative to deriving from Thread. Handles exceptions in a +// different way from Thread, instead of logging the exception, the exception +// is forwarded and re-thrown during the call to finish(). +template <> +class ThreadFunction<void> { +public: + ThreadFunction(); + ThreadFunction(ThreadFunction&&); + + // Automatically starts the given function, ThreadFunction can also be + // constructed with Thread::invoke, which is a shorthand. + ThreadFunction(function<void()> function, String const& name); + + // Automatically calls finish, though BEWARE that often times this is quite + // dangerous, and this is here mostly as a fallback. The natural destructor + // order for members of a class is often wrong, and if the function throws, + // since this destructor calls finish it will throw. + ~ThreadFunction(); + + ThreadFunction& operator=(ThreadFunction&&); + + // Waits on function finish if function is assigned and started, otherwise + // does nothing. If the function threw an exception, it will be re-thrown + // here (on the first call to finish() only). + void finish(); + + // Returns whether the ThreadFunction::finish method been called and the + // ThreadFunction has stopped. Also returns true when the ThreadFunction has + // been default constructed. + bool isFinished() const; + // Returns false if the thread function has stopped running, whether or not + // finish() has been called. + bool isRunning() const; + + // Equivalent to !isFinished() + explicit operator bool() const; + + String name(); + +private: + unique_ptr<ThreadFunctionImpl> m_impl; +}; + +template <typename Return> +class ThreadFunction { +public: + ThreadFunction(); + ThreadFunction(ThreadFunction&&); + ThreadFunction(function<Return()> function, String const& name); + + ~ThreadFunction(); + + ThreadFunction& operator=(ThreadFunction&&); + + // Finishes the thread, moving and returning the final value of the function. + // If the function threw an exception, finish() will rethrow that exception. + // May only be called once, otherwise will throw InvalidMaybeAccessException. + Return finish(); + + bool isFinished() const; + bool isRunning() const; + + explicit operator bool() const; + + String name(); + +private: + ThreadFunction<void> m_function; + shared_ptr<Maybe<Return>> m_return; +}; + +// *Non* recursive mutex lock, for use with ConditionVariable +class Mutex { +public: + Mutex(); + Mutex(Mutex&&); + ~Mutex(); + + Mutex& operator=(Mutex&&); + + void lock(); + + // Attempt to acquire the mutex without blocking. + bool tryLock(); + + void unlock(); + +private: + friend struct ConditionVariableImpl; + unique_ptr<MutexImpl> m_impl; +}; + +class ConditionVariable { +public: + ConditionVariable(); + ConditionVariable(ConditionVariable&&); + ~ConditionVariable(); + + ConditionVariable& operator=(ConditionVariable&&); + + // Atomically unlocks the mutex argument and waits on the condition. On + // acquiring the condition, atomically returns and re-locks the mutex. Must + // lock the mutex before calling. If millis is given, waits for a maximum of + // the given milliseconds only. + void wait(Mutex& mutex, Maybe<unsigned> millis = {}); + + // Wake one waiting thread. The calling thread for is allowed to either hold + // or not hold the mutex that the threads waiting on the condition are using, + // both will work and result in slightly different scheduling. + void signal(); + + // Wake all threads, policy for holding the mutex is the same for signal(). + void broadcast(); + +private: + unique_ptr<ConditionVariableImpl> m_impl; +}; + +// Recursive mutex lock. lock() may be called many times freely by the same +// thread, but unlock() must be called an equal number of times to unlock it. +class RecursiveMutex { +public: + RecursiveMutex(); + RecursiveMutex(RecursiveMutex&&); + ~RecursiveMutex(); + + RecursiveMutex& operator=(RecursiveMutex&&); + + void lock(); + + // Attempt to acquire the mutex without blocking. + bool tryLock(); + + void unlock(); + +private: + unique_ptr<RecursiveMutexImpl> m_impl; +}; + +// RAII for mutexes. Locking and unlocking are always safe, MLocker will never +// attempt to lock the held mutex more than once, or unlock more than once, and +// destruction will always unlock the mutex *iff* it is actually locked. +// (Locked here refers to one specific MLocker *itself* locking the mutex, not +// whether the mutex is locked *at all*, so it is sensible to use with +// RecursiveMutex) +template <typename MutexType> +class MLocker { +public: + // Pass false to lock to start unlocked + MLocker(MutexType& ref, bool lock = true); + ~MLocker(); + + MLocker(MLocker const&) = delete; + MLocker& operator=(MLocker const&) = delete; + + MutexType& mutex(); + + void unlock(); + void lock(); + bool tryLock(); + +private: + MutexType& m_mutex; + bool m_locked; +}; +typedef MLocker<Mutex> MutexLocker; +typedef MLocker<RecursiveMutex> RecursiveMutexLocker; + +class ReadersWriterMutex { +public: + ReadersWriterMutex(); + + void readLock(); + bool tryReadLock(); + void readUnlock(); + + void writeLock(); + bool tryWriteLock(); + void writeUnlock(); + +private: + Mutex m_mutex; + ConditionVariable m_readCond; + ConditionVariable m_writeCond; + unsigned m_readers; + unsigned m_writers; + unsigned m_readWaiters; + unsigned m_writeWaiters; +}; + +class ReadLocker { +public: + ReadLocker(ReadersWriterMutex& rwlock, bool startLocked = true); + ~ReadLocker(); + + ReadLocker(ReadLocker const&) = delete; + ReadLocker& operator=(ReadLocker const&) = delete; + + void unlock(); + void lock(); + bool tryLock(); + +private: + ReadersWriterMutex& m_lock; + bool m_locked; +}; + +class WriteLocker { +public: + WriteLocker(ReadersWriterMutex& rwlock, bool startLocked = true); + ~WriteLocker(); + + WriteLocker(WriteLocker const&) = delete; + WriteLocker& operator=(WriteLocker const&) = delete; + + void unlock(); + void lock(); + bool tryLock(); + +private: + ReadersWriterMutex& m_lock; + bool m_locked; +}; + +class SpinLock { +public: + SpinLock(); + + void lock(); + bool tryLock(); + void unlock(); + +private: + atomic_flag m_lock; +}; +typedef MLocker<SpinLock> SpinLocker; + +template <typename MutexType> +MLocker<MutexType>::MLocker(MutexType& ref, bool l) + : m_mutex(ref), m_locked(false) { + if (l) + lock(); +} + +template <typename MutexType> +MLocker<MutexType>::~MLocker() { + unlock(); +} + +template <typename MutexType> +MutexType& MLocker<MutexType>::mutex() { + return m_mutex; +} + +template <typename MutexType> +void MLocker<MutexType>::unlock() { + if (m_locked) { + m_mutex.unlock(); + m_locked = false; + } +} + +template <typename MutexType> +void MLocker<MutexType>::lock() { + if (!m_locked) { + m_mutex.lock(); + m_locked = true; + } +} + +template <typename MutexType> +bool MLocker<MutexType>::tryLock() { + if (!m_locked) { + if (m_mutex.tryLock()) + m_locked = true; + } + + return m_locked; +} + +template <typename Function, typename... Args> +ThreadFunction<decltype(std::declval<Function>()(std::declval<Args>()...))> Thread::invoke(String const& name, Function&& f, Args&&... args) { + return {bind(forward<Function>(f), forward<Args>(args)...), name}; +} + +template <typename Return> +ThreadFunction<Return>::ThreadFunction() {} + +template <typename Return> +ThreadFunction<Return>::ThreadFunction(ThreadFunction&&) = default; + +template <typename Return> +ThreadFunction<Return>::ThreadFunction(function<Return()> function, String const& name) { + m_return = make_shared<Maybe<Return>>(); + m_function = ThreadFunction<void>([function = move(function), retValue = m_return]() { + *retValue = function(); + }, name); +} + +template <typename Return> +ThreadFunction<Return>::~ThreadFunction() { + m_function.finish(); +} + +template <typename Return> +ThreadFunction<Return>& ThreadFunction<Return>::operator=(ThreadFunction&&) = default; + +template <typename Return> +Return ThreadFunction<Return>::finish() { + m_function.finish(); + return m_return->take(); +} + +template <typename Return> +bool ThreadFunction<Return>::isFinished() const { + return m_function.isFinished(); +} + +template <typename Return> +bool ThreadFunction<Return>::isRunning() const { + return m_function.isRunning(); +} + +template <typename Return> +ThreadFunction<Return>::operator bool() const { + return !isFinished(); +} + +template <typename Return> +String ThreadFunction<Return>::name() { + return m_function.name(); +} + +inline SpinLock::SpinLock() { + m_lock.clear(); +} + +inline void SpinLock::lock() { + while (m_lock.test_and_set(std::memory_order_acquire)) + ; +} + +inline void SpinLock::unlock() { + m_lock.clear(std::memory_order_release); +} + +inline bool SpinLock::tryLock() { + return !m_lock.test_and_set(std::memory_order_acquire); +} + +} + +#endif diff --git a/source/core/StarThread_unix.cpp b/source/core/StarThread_unix.cpp new file mode 100644 index 0000000..ba635f9 --- /dev/null +++ b/source/core/StarThread_unix.cpp @@ -0,0 +1,390 @@ +#include "StarThread.hpp" +#include "StarTime.hpp" +#include "StarLogging.hpp" + +#include <limits.h> +#include <libgen.h> +#include <stdlib.h> +#include <string.h> +#include <unistd.h> +#include <dlfcn.h> +#include <dirent.h> +#include <pthread.h> +#ifdef STAR_SYSTEM_FREEBSD +#include <pthread_np.h> +#endif +#include <sys/time.h> +#include <errno.h> + +#ifdef MAXCOMLEN +#define MAX_THREAD_NAMELEN MAXCOMLEN +#else +#define MAX_THREAD_NAMELEN 16 +#endif + +namespace Star { + +struct ThreadImpl { + static void* runThread(void* data) { + ThreadImpl* ptr = static_cast<ThreadImpl*>(data); + try { +#ifdef STAR_SYSTEM_MACOS + // ensure the name is under the max allowed + char tname[MAX_THREAD_NAMELEN]; + snprintf(tname, sizeof(tname), "%s", ptr->name.utf8Ptr()); + + pthread_setname_np(tname); +#endif + ptr->function(); + } catch (std::exception const& e) { + if (ptr->name.empty()) + Logger::error("Exception caught in Thread: %s", outputException(e, true)); + else + Logger::error("Exception caught in Thread %s: %s", ptr->name, outputException(e, true)); + } catch (...) { + if (ptr->name.empty()) + Logger::error("Unknown exception caught in Thread"); + else + Logger::error("Unknown exception caught in Thread %s", ptr->name); + } + ptr->stopped = true; + return nullptr; + } + + ThreadImpl(std::function<void()> function, String name) + : function(std::move(function)), name(std::move(name)), stopped(true), joined(true) {} + + bool start() { + MutexLocker mutexLocker(mutex); + if (!joined) + return false; + + stopped = false; + joined = false; + int ret = pthread_create(&pthread, NULL, &runThread, (void*)this); + if (ret != 0) { + stopped = true; + joined = true; + throw StarException(strf("Failed to create thread, error %s", ret)); + } + + // ensure the name is under the max allowed + char tname[MAX_THREAD_NAMELEN]; + snprintf(tname, sizeof(tname), "%s", name.utf8Ptr()); + +#ifdef STAR_SYSTEM_FREEBSD + pthread_set_name_np(pthread, tname); +#elif not defined STAR_SYSTEM_MACOS + pthread_setname_np(pthread, tname); +#endif + return true; + } + + bool join() { + MutexLocker mutexLocker(mutex); + if (joined) + return false; + int ret = pthread_join(pthread, NULL); + if (ret != 0) + throw StarException(strf("Failed to join thread, error %s", ret)); + joined = true; + return true; + } + + std::function<void()> function; + String name; + pthread_t pthread; + atomic<bool> stopped; + bool joined; + Mutex mutex; +}; + +struct ThreadFunctionImpl : ThreadImpl { + ThreadFunctionImpl(std::function<void()> function, String name) + : ThreadImpl(wrapFunction(move(function)), move(name)) {} + + std::function<void()> wrapFunction(std::function<void()> function) { + return [function = move(function), this]() { + try { + function(); + } catch (...) { + exception = std::current_exception(); + } + }; + } + + std::exception_ptr exception; +}; + +struct MutexImpl { + MutexImpl() { + pthread_mutexattr_t mutexattr; + pthread_mutexattr_init(&mutexattr); + + pthread_mutex_init(&mutex, &mutexattr); + + pthread_mutexattr_destroy(&mutexattr); + } + + ~MutexImpl() { + pthread_mutex_destroy(&mutex); + } + + void lock() { + pthread_mutex_lock(&mutex); + } + + void unlock() { + pthread_mutex_unlock(&mutex); + } + + bool tryLock() { + if (pthread_mutex_trylock(&mutex) == 0) + return true; + else + return false; + } + + pthread_mutex_t mutex; +}; + +struct ConditionVariableImpl { + ConditionVariableImpl() { + pthread_cond_init(&condition, NULL); + } + + ~ConditionVariableImpl() { + pthread_cond_destroy(&condition); + } + + void wait(Mutex& mutex) { + pthread_cond_wait(&condition, &mutex.m_impl->mutex); + } + + void wait(Mutex& mutex, unsigned millis) { + int64_t time = Time::millisecondsSinceEpoch() + millis; + + timespec ts; + ts.tv_sec = time / 1000; + ts.tv_nsec = (time % 1000) * 1000000; + + pthread_cond_timedwait(&condition, &mutex.m_impl->mutex, &ts); + } + + void signal() { + pthread_cond_signal(&condition); + } + + void broadcast() { + pthread_cond_broadcast(&condition); + } + + pthread_cond_t condition; +}; + +struct RecursiveMutexImpl { + RecursiveMutexImpl() { + pthread_mutexattr_t mutexattr; + pthread_mutexattr_init(&mutexattr); + + pthread_mutexattr_settype(&mutexattr, PTHREAD_MUTEX_RECURSIVE); + + pthread_mutex_init(&mutex, &mutexattr); + + pthread_mutexattr_destroy(&mutexattr); + } + + ~RecursiveMutexImpl() { + pthread_mutex_destroy(&mutex); + } + + void lock() { + pthread_mutex_lock(&mutex); + } + + void unlock() { + pthread_mutex_unlock(&mutex); + } + + bool tryLock() { + if (pthread_mutex_trylock(&mutex) == 0) + return true; + else + return false; + } + + pthread_mutex_t mutex; +}; + +void Thread::sleepPrecise(unsigned msecs) { + int64_t now = Time::monotonicMilliseconds(); + int64_t deadline = now + msecs; + + while (deadline - now > 10) { + usleep((deadline - now - 10) * 1000); + now = Time::monotonicMilliseconds(); + } + + while (deadline > now) { + usleep((deadline - now) * 500); + now = Time::monotonicMilliseconds(); + } +} + +void Thread::sleep(unsigned msecs) { + usleep(msecs * 1000); +} + +void Thread::yield() { + sched_yield(); +} + +unsigned Thread::numberOfProcessors() { + long nprocs = sysconf(_SC_NPROCESSORS_ONLN); + if (nprocs < 1) + throw StarException(strf("Could not determine number of CPUs online: %s\n", strerror(errno))); + return nprocs; +} + +Thread::Thread(String const& name) { + m_impl.reset(new ThreadImpl([this]() { + run(); + }, name)); +} + +Thread::Thread(Thread&&) = default; + +Thread::~Thread() {} + +Thread& Thread::operator=(Thread&&) = default; + +bool Thread::start() { + return m_impl->start(); +} + +bool Thread::join() { + return m_impl->join(); +} + +String Thread::name() { + return m_impl->name; +} + +bool Thread::isJoined() const { + return m_impl->joined; +} + +bool Thread::isRunning() const { + return !m_impl->stopped; +} + +ThreadFunction<void>::ThreadFunction() {} + +ThreadFunction<void>::ThreadFunction(ThreadFunction&&) = default; + +ThreadFunction<void>::ThreadFunction(function<void()> function, String const& name) { + m_impl.reset(new ThreadFunctionImpl(move(function), name)); + m_impl->start(); +} + +ThreadFunction<void>::~ThreadFunction() { + finish(); +} + +ThreadFunction<void>& ThreadFunction<void>::operator=(ThreadFunction&&) = default; + +void ThreadFunction<void>::finish() { + if (m_impl) { + m_impl->join(); + + if (m_impl->exception) + std::rethrow_exception(take(m_impl->exception)); + } +} + +bool ThreadFunction<void>::isFinished() const { + return !m_impl || m_impl->joined; +} + +bool ThreadFunction<void>::isRunning() const { + return m_impl && !m_impl->stopped; +} + +ThreadFunction<void>::operator bool() const { + return !isFinished(); +} + +String ThreadFunction<void>::name() { + if (m_impl) + return m_impl->name; + else + return ""; +} + +Mutex::Mutex() + : m_impl(new MutexImpl()) {} + +Mutex::Mutex(Mutex&&) = default; + +Mutex::~Mutex() {} + +Mutex& Mutex::operator=(Mutex&&) = default; + +void Mutex::lock() { + m_impl->lock(); +} + +bool Mutex::tryLock() { + return m_impl->tryLock(); +} + +void Mutex::unlock() { + m_impl->unlock(); +} + +ConditionVariable::ConditionVariable() + : m_impl(new ConditionVariableImpl()) {} + +ConditionVariable::ConditionVariable(ConditionVariable&&) = default; + +ConditionVariable::~ConditionVariable() {} + +ConditionVariable& ConditionVariable::operator=(ConditionVariable&&) = default; + +void ConditionVariable::wait(Mutex& mutex, Maybe<unsigned> millis) { + if (millis) + m_impl->wait(mutex, *millis); + else + m_impl->wait(mutex); +} + +void ConditionVariable::signal() { + m_impl->signal(); +} + +void ConditionVariable::broadcast() { + m_impl->broadcast(); +} + +RecursiveMutex::RecursiveMutex() + : m_impl(new RecursiveMutexImpl()) {} + +RecursiveMutex::RecursiveMutex(RecursiveMutex&&) = default; + +RecursiveMutex::~RecursiveMutex() {} + +RecursiveMutex& RecursiveMutex::operator=(RecursiveMutex&&) = default; + +void RecursiveMutex::lock() { + m_impl->lock(); +} + +bool RecursiveMutex::tryLock() { + return m_impl->tryLock(); +} + +void RecursiveMutex::unlock() { + m_impl->unlock(); +} + +} diff --git a/source/core/StarThread_windows.cpp b/source/core/StarThread_windows.cpp new file mode 100644 index 0000000..86ebbf5 --- /dev/null +++ b/source/core/StarThread_windows.cpp @@ -0,0 +1,544 @@ +#include "StarThread.hpp" +#include "StarTime.hpp" +#include "StarLogging.hpp" +#include "StarDynamicLib.hpp" + +#include <windows.h> +#include <stdio.h> +#include <process.h> +#include <locale.h> + +namespace Star { + +// This is the CONDITIONAL_VARIABLE typedef for using Window's native +// conditional variables on kernels 6.0+. +// MinGW does not currently have this typedef. +typedef struct pthread_cond_t { void* ptr; } CONDITIONAL_VARIABLE; + +// Static thread initialization +struct ThreadSupport { + ThreadSupport() { + HMODULE kernel_dll = GetModuleHandle(TEXT("kernel32.dll")); + initializeConditionVariable = (InitializeConditionVariablePtr)GetProcAddress(kernel_dll, "InitializeConditionVariable"); + wakeAllConditionVariable = (WakeAllConditionVariablePtr)GetProcAddress(kernel_dll, "WakeAllConditionVariable"); + wakeConditionVariable = (WakeConditionVariablePtr)GetProcAddress(kernel_dll, "WakeConditionVariable"); + sleepConditionVariableCS = (SleepConditionVariableCSPtr)GetProcAddress(kernel_dll, "SleepConditionVariableCS"); + + nativeConditionVariables = initializeConditionVariable && wakeAllConditionVariable && wakeConditionVariable && sleepConditionVariableCS; + } + + typedef void(WINAPI* InitializeConditionVariablePtr)(CONDITIONAL_VARIABLE* cond); + typedef void(WINAPI* WakeAllConditionVariablePtr)(CONDITIONAL_VARIABLE* cond); + typedef void(WINAPI* WakeConditionVariablePtr)(CONDITIONAL_VARIABLE* cond); + typedef BOOL(WINAPI* SleepConditionVariableCSPtr)(CONDITIONAL_VARIABLE* cond, CRITICAL_SECTION* mutex, DWORD milliseconds); + + // function pointers to conditional variable API on windows 6.0+ kernels + InitializeConditionVariablePtr initializeConditionVariable; + WakeAllConditionVariablePtr wakeAllConditionVariable; + WakeConditionVariablePtr wakeConditionVariable; + SleepConditionVariableCSPtr sleepConditionVariableCS; + + bool nativeConditionVariables; +}; +static ThreadSupport g_threadSupport; + +struct ThreadImpl { + static DWORD WINAPI runThread(void* data) { + ThreadImpl* ptr = static_cast<ThreadImpl*>(data); + try { + ptr->function(); + } catch (std::exception const& e) { + if (ptr->name.empty()) + Logger::error("Exception caught in Thread: %s", outputException(e, true)); + else + Logger::error("Exception caught in Thread %s: %s", ptr->name, outputException(e, true)); + } catch (...) { + if (ptr->name.empty()) + Logger::error("Unknown exception caught in Thread"); + else + Logger::error("Unknown exception caught in Thread %s", ptr->name); + } + ptr->stopped = true; + return 0; + } + + ThreadImpl(std::function<void()> function, String name) + : function(std::move(function)), name(std::move(name)), thread(INVALID_HANDLE_VALUE), stopped(true) {} + + bool start() { + MutexLocker mutexLocker(mutex); + if (thread != INVALID_HANDLE_VALUE) + return false; + + stopped = false; + if (thread == INVALID_HANDLE_VALUE) + thread = CreateThread(NULL, 0, runThread, (void*)this, 0, NULL); + if (thread == NULL) + thread = INVALID_HANDLE_VALUE; + if (thread == INVALID_HANDLE_VALUE) { + stopped = true; + throw StarException("Failed to create thread"); + } + return true; + } + + bool join() { + MutexLocker mutexLocker(mutex); + if (thread == INVALID_HANDLE_VALUE) + return false; + WaitForSingleObject(thread, INFINITE); + CloseHandle(thread); + thread = INVALID_HANDLE_VALUE; + return true; + } + + std::function<void()> function; + String name; + HANDLE thread; + atomic<bool> stopped; + +private: + ThreadImpl(ThreadImpl const&); + ThreadImpl& operator=(ThreadImpl const&); + + Mutex mutex; +}; + +struct ThreadFunctionImpl : ThreadImpl { + ThreadFunctionImpl(std::function<void()> function, String name) + : ThreadImpl(wrapFunction(move(function)), move(name)) {} + + std::function<void()> wrapFunction(std::function<void()> function) { + return [function = move(function), this]() { + try { + function(); + } catch (...) { + exception = std::current_exception(); + } + }; + } + + std::exception_ptr exception; +}; + +struct MutexImpl { + MutexImpl() { + InitializeCriticalSection(&criticalSection); + } + + ~MutexImpl() { + DeleteCriticalSection(&criticalSection); + } + + void lock() { + EnterCriticalSection(&criticalSection); + } + + void unlock() { + LeaveCriticalSection(&criticalSection); + } + + bool tryLock() { + return TryEnterCriticalSection(&criticalSection); + } + + CRITICAL_SECTION criticalSection; +}; + +struct ConditionVariableImpl { + ConditionVariableImpl() { + if (g_threadSupport.nativeConditionVariables) { + m_impl = make_unique<NativeImpl>(); + } else { + m_impl = make_unique<EmulatedImpl>(); + } + } + + void wait(Mutex& mutex) { + m_impl->wait(mutex); + } + + void wait(Mutex& mutex, unsigned millis) { + m_impl->wait(mutex, millis); + } + + void signal() { + m_impl->signal(); + } + + void broadcast() { + m_impl->broadcast(); + } + +private: + struct Impl { + virtual ~Impl() {} + + virtual void wait(Mutex& mutex) = 0; + virtual void wait(Mutex& mutex, unsigned millis) = 0; + virtual void signal() = 0; + virtual void broadcast() = 0; + }; + + struct NativeImpl : Impl { + NativeImpl() { + g_threadSupport.initializeConditionVariable(&conditionVariable); + } + + void wait(Mutex& mutex) override { + g_threadSupport.sleepConditionVariableCS(&conditionVariable, &mutex.m_impl->criticalSection, INFINITE); + } + + void wait(Mutex& mutex, unsigned millis) override { + g_threadSupport.sleepConditionVariableCS(&conditionVariable, &mutex.m_impl->criticalSection, millis); + } + + void signal() override { + g_threadSupport.wakeConditionVariable(&conditionVariable); + } + + void broadcast() override { + g_threadSupport.wakeAllConditionVariable(&conditionVariable); + } + + CONDITIONAL_VARIABLE conditionVariable; + }; + + struct EmulatedImpl : Impl { + EmulatedImpl() { + numThreads = 0; + isBroadcasting = 0; + threadSemaphore = CreateSemaphore(NULL, // no security + 0, // initially 0 + 0x7fffffff, // max count + NULL); // unnamed + + InitializeCriticalSection(&numThreadsConditionMutex); + + broadcastDone = CreateEvent(NULL, // no security + FALSE, // auto-reset + FALSE, // non-signaled initially + NULL); // unnamed + } + + virtual ~EmulatedImpl() { + CloseHandle(threadSemaphore); + CloseHandle(broadcastDone); + DeleteCriticalSection(&numThreadsConditionMutex); + } + + void wait(Mutex& mutex) override { + // Avoid race conditions. + EnterCriticalSection(&numThreadsConditionMutex); + numThreads++; + LeaveCriticalSection(&numThreadsConditionMutex); + + // Release the mutex and waits on the semaphore until signal or broadcast + // are called by another thread. + LeaveCriticalSection(&mutex.m_impl->criticalSection); + WaitForSingleObject(threadSemaphore, INFINITE); + + // Reacquire lock to avoid race conditions. + EnterCriticalSection(&numThreadsConditionMutex); + + // We're no longer waiting... + numThreads--; + + // Check to see if we're the last waiter after broadcast + bool last_waiter = isBroadcasting && numThreads == 0; + + LeaveCriticalSection(&numThreadsConditionMutex); + + // If we're the last waiter thread during this particular broadcast + // then let all the other threads proceed. + if (last_waiter) + SetEvent(broadcastDone); + EnterCriticalSection(&mutex.m_impl->criticalSection); + } + + void wait(Mutex& mutex, unsigned millis) override { + // Avoid race conditions. + EnterCriticalSection(&numThreadsConditionMutex); + numThreads++; + LeaveCriticalSection(&numThreadsConditionMutex); + + // Release the mutex and waits on the semaphore until signal or broadcast + // are called by another thread. + LeaveCriticalSection(&mutex.m_impl->criticalSection); + WaitForSingleObject(threadSemaphore, millis); + + // Reacquire lock to avoid race conditions. + EnterCriticalSection(&numThreadsConditionMutex); + + // We're no longer waiting... + numThreads--; + + // Check to see if we're the last waiter after broadcast + bool last_waiter = isBroadcasting && numThreads == 0; + + LeaveCriticalSection(&numThreadsConditionMutex); + + // If we're the last waiter thread during this particular broadcast + // then let all the other threads proceed. + if (last_waiter) + SetEvent(broadcastDone); + EnterCriticalSection(&mutex.m_impl->criticalSection); + } + + void signal() override { + EnterCriticalSection(&numThreadsConditionMutex); + bool have_waiters = numThreads > 0; + LeaveCriticalSection(&numThreadsConditionMutex); + + // If there aren't any waiters, then this is a no-op. + if (have_waiters) + ReleaseSemaphore(threadSemaphore, 1, 0); + } + + void broadcast() override { + // This is needed to ensure that <numThreads> and <isBroadcasting> are + // consistent relative to each other. + EnterCriticalSection(&numThreadsConditionMutex); + bool have_waiters = 0; + + if (numThreads > 0) { + // We are broadcasting, even if there is just one waiter... + // Record that we are broadcasting + isBroadcasting = 1; + have_waiters = 1; + } + + if (have_waiters) { + // Wake up all the waiters atomically. + ReleaseSemaphore(threadSemaphore, numThreads, 0); + + LeaveCriticalSection(&numThreadsConditionMutex); + + // Wait for all the awakened threads to acquire the counting + // semaphore. + WaitForSingleObject(broadcastDone, INFINITE); + // This assignment is okay, even without the <numThreadsConditionMutex> + // held + // because no other waiter threads can wake up to access it. + isBroadcasting = 0; + } else { + LeaveCriticalSection(&numThreadsConditionMutex); + } + } + + // Number of waiting threads. + int numThreads; + + // Serialize access to <numThreads>. + CRITICAL_SECTION numThreadsConditionMutex; + + // Semaphore used to queue up threads waiting for the condition to + // become signaled. + HANDLE threadSemaphore; + + // An auto-reset event used by the broadcast/signal thread to wait + // for all the waiting thread(s) to wake up and be released from the + // semaphore. + HANDLE broadcastDone; + + // Keeps track of whether we were broadcasting or signaling. This + // allows us to optimize the code if we're just signaling. + size_t isBroadcasting; + }; + + unique_ptr<Impl> m_impl; +}; + +struct RecursiveMutexImpl { + RecursiveMutexImpl() { + InitializeCriticalSection(&criticalSection); + } + + ~RecursiveMutexImpl() { + DeleteCriticalSection(&criticalSection); + } + + void lock() { + EnterCriticalSection(&criticalSection); + } + + void unlock() { + LeaveCriticalSection(&criticalSection); + } + + bool tryLock() { + return TryEnterCriticalSection(&criticalSection); + } + + CRITICAL_SECTION criticalSection; +}; + +void Thread::sleepPrecise(unsigned msecs) { + int64_t now = Time::monotonicMilliseconds(); + int64_t deadline = now + msecs; + + while (deadline - now > 10) { + Sleep(deadline - now - 10); + now = Time::monotonicMilliseconds(); + } + + while (deadline > now) { + if (deadline - now >= 2) + Sleep((deadline - now) / 2); + else + Sleep(0); + now = Time::monotonicMilliseconds(); + } +} + +void Thread::sleep(unsigned msecs) { + Sleep(msecs); +} + +void Thread::yield() { + SwitchToThread(); +} + +unsigned Thread::numberOfProcessors() { + SYSTEM_INFO info; + GetSystemInfo(&info); + return info.dwNumberOfProcessors; +} + +Thread::Thread(String const& name) { + m_impl.reset(new ThreadImpl([this]() { run(); }, name)); +} + +Thread::Thread(Thread&&) = default; + +Thread::~Thread() {} + +Thread& Thread::operator=(Thread&&) = default; + +bool Thread::start() { + return m_impl->start(); +} + +bool Thread::join() { + return m_impl->join(); +} + +String Thread::name() { + return m_impl->name; +} + +bool Thread::isJoined() const { + return m_impl->thread == INVALID_HANDLE_VALUE; +} + +bool Thread::isRunning() const { + return !m_impl->stopped; +} + +ThreadFunction<void>::ThreadFunction() {} + +ThreadFunction<void>::ThreadFunction(ThreadFunction&&) = default; + +ThreadFunction<void>::ThreadFunction(function<void()> function, String const& name) { + m_impl.reset(new ThreadFunctionImpl(move(function), name)); + m_impl->start(); +} + +ThreadFunction<void>::~ThreadFunction() { + finish(); +} + +ThreadFunction<void>& ThreadFunction<void>::operator=(ThreadFunction&&) = default; + +void ThreadFunction<void>::finish() { + if (m_impl) { + m_impl->join(); + + if (m_impl->exception) + std::rethrow_exception(take(m_impl->exception)); + } +} + +bool ThreadFunction<void>::isFinished() const { + return !m_impl || m_impl->thread == INVALID_HANDLE_VALUE; +} + +bool ThreadFunction<void>::isRunning() const { + return m_impl && !m_impl->stopped; +} + +ThreadFunction<void>::operator bool() const { + return !isFinished(); +} + +String ThreadFunction<void>::name() { + if (m_impl) + return m_impl->name; + else + return ""; +} + +Mutex::Mutex() : m_impl(new MutexImpl()) {} + +Mutex::Mutex(Mutex&&) = default; + +Mutex::~Mutex() {} + +Mutex& Mutex::operator=(Mutex&&) = default; + +void Mutex::lock() { + m_impl->lock(); +} + +bool Mutex::tryLock() { + return m_impl->tryLock(); +} + +void Mutex::unlock() { + m_impl->unlock(); +} + +ConditionVariable::ConditionVariable() : m_impl(new ConditionVariableImpl()) {} + +ConditionVariable::ConditionVariable(ConditionVariable&&) = default; + +ConditionVariable::~ConditionVariable() {} + +ConditionVariable& ConditionVariable::operator=(ConditionVariable&&) = default; + +void ConditionVariable::wait(Mutex& mutex, Maybe<unsigned> millis) { + if (millis) + m_impl->wait(mutex, *millis); + else + m_impl->wait(mutex); +} + +void ConditionVariable::signal() { + m_impl->signal(); +} + +void ConditionVariable::broadcast() { + m_impl->broadcast(); +} + +RecursiveMutex::RecursiveMutex() : m_impl(new RecursiveMutexImpl()) {} + +RecursiveMutex::RecursiveMutex(RecursiveMutex&&) = default; + +RecursiveMutex::~RecursiveMutex() {} + +RecursiveMutex& RecursiveMutex::operator=(RecursiveMutex&&) = default; + +void RecursiveMutex::lock() { + m_impl->lock(); +} + +bool RecursiveMutex::tryLock() { + return m_impl->tryLock(); +} + +void RecursiveMutex::unlock() { + m_impl->unlock(); +} + +} diff --git a/source/core/StarTickRateMonitor.cpp b/source/core/StarTickRateMonitor.cpp new file mode 100644 index 0000000..78a423a --- /dev/null +++ b/source/core/StarTickRateMonitor.cpp @@ -0,0 +1,87 @@ +#include "StarTickRateMonitor.hpp" +#include "StarTime.hpp" + +namespace Star { + +TickRateMonitor::TickRateMonitor(double window) : m_window(window) { + reset(); +} + +double TickRateMonitor::window() const { + return m_window; +} + +void TickRateMonitor::reset() { + m_lastTick = Time::monotonicTime() - m_window; + m_ticks = 0; +} + +double TickRateMonitor::tick(unsigned count) { + double currentTime = Time::monotonicTime(); + + if (m_lastTick > currentTime) { + m_lastTick = currentTime - m_window; + m_ticks = 0; + } else if (m_lastTick < currentTime) { + double timePast = currentTime - m_lastTick; + double rate = m_ticks / m_window; + m_ticks = max(0.0, m_ticks - timePast * rate); + m_lastTick = currentTime; + } + + m_ticks += count; + + return m_ticks / m_window; +} + +double TickRateMonitor::rate() const { + return TickRateMonitor(*this).tick(0); +} + +TickRateApproacher::TickRateApproacher(double targetTickRate, double window) + : m_tickRateMonitor(window), m_targetTickRate(targetTickRate) {} + +double TickRateApproacher::window() const { + return m_tickRateMonitor.window(); +} + +void TickRateApproacher::setWindow(double window) { + if (window != m_tickRateMonitor.window()) { + m_tickRateMonitor = TickRateMonitor(window); + tick(m_targetTickRate * window); + } +} + +double TickRateApproacher::targetTickRate() const { + return m_targetTickRate; +} + +void TickRateApproacher::setTargetTickRate(double targetTickRate) { + m_targetTickRate = targetTickRate; +} + +void TickRateApproacher::reset() { + setWindow(window()); +} + +double TickRateApproacher::tick(unsigned count) { + return m_tickRateMonitor.tick(count); +} + +double TickRateApproacher::rate() const { + return m_tickRateMonitor.rate(); +} + +double TickRateApproacher::ticksBehind() { + return (m_targetTickRate - m_tickRateMonitor.rate()) * window(); +} + +double TickRateApproacher::ticksAhead() { + return -ticksBehind(); +} + +double TickRateApproacher::spareTime() { + return ticksAhead() / m_targetTickRate; +} + +} diff --git a/source/core/StarTickRateMonitor.hpp b/source/core/StarTickRateMonitor.hpp new file mode 100644 index 0000000..a50ff5e --- /dev/null +++ b/source/core/StarTickRateMonitor.hpp @@ -0,0 +1,78 @@ +#ifndef STAR_TICK_RATE_MONITOR_HPP +#define STAR_TICK_RATE_MONITOR_HPP + +#include "StarList.hpp" + +namespace Star { + +// Monitors the rate at which 'tick()' is called in wall-clock seconds. +class TickRateMonitor { +public: + // 'window' controls the dropoff at which 'rate' will approach zero if tick + // is not called, measured in seconds. + TickRateMonitor(double window); + + double window() const; + + // Resets to a zero tick-rate state + void reset(); + + // Ticks the given number of times, returns the current rate. + double tick(unsigned count = 1); + + // Returns the rate as of the *current* time, not the time of the last tick. + double rate() const; + +private: + void dropOff(double currentTime); + + double m_window; + double m_lastTick; + double m_ticks; +}; + +// Helps tick at as close as possible to a given tick rate +class TickRateApproacher { +public: + TickRateApproacher(double targetTickRate, double window); + + // The TickRateMonitor window influences how long the TickRateApproacher will + // try and speed up or slow down the tick rate to match the target tick rate. + // It should be chosen so that it is not so short that the actual target rate + // drifts, but not too long so that the rate returns to normal quickly enough + // with outliers. + double window() const; + // Setting the window to a new value will reset the TickRateApproacher + void setWindow(double window); + + double targetTickRate() const; + void setTargetTickRate(double targetTickRate); + + // Resets such that the current tick rate is assumed to be perfectly at the + // target. + void reset(); + + double tick(unsigned count = 1); + double rate() const; + + // How many ticks we currently should perform, so that if each tick happened + // instantly, we would be as close to the target tick rate as possible. If + // we are ahead, may be negative. + double ticksBehind(); + + // The negative of ticksBehind, is positive for how many ticks ahead we + // currently are. + double ticksAhead(); + + // How much spare time we have until the tick rate will begin to be behind + // the target tick rate. + double spareTime(); + +private: + TickRateMonitor m_tickRateMonitor; + double m_targetTickRate; +}; + +}; + +#endif diff --git a/source/core/StarTime.cpp b/source/core/StarTime.cpp new file mode 100644 index 0000000..b8d01a4 --- /dev/null +++ b/source/core/StarTime.cpp @@ -0,0 +1,208 @@ +#include "StarTime.hpp" +#include "StarMathCommon.hpp" +#include "StarLexicalCast.hpp" + +namespace Star { + +double Time::timeSinceEpoch() { + return ticksToSeconds(epochTicks(), epochTickFrequency()); +} + +int64_t Time::millisecondsSinceEpoch() { + return ticksToMilliseconds(epochTicks(), epochTickFrequency()); +} + +double Time::monotonicTime() { + return ticksToSeconds(monotonicTicks(), monotonicTickFrequency()); +} + +int64_t Time::monotonicMilliseconds() { + return ticksToMilliseconds(monotonicTicks(), monotonicTickFrequency()); +} + +String Time::printDuration(double time) { + String hours; + String minutes; + String seconds; + String milliseconds; + + if (time >= 3600) { + int numHours = (int)time / 3600; + hours = strf("%d hour%s", numHours, numHours == 1 ? "" : "s"); + } + if (time >= 60) { + int numMinutes = (int)(time / 60) % 60; + minutes = strf("%s minute%s", numMinutes, numMinutes == 1 ? "" : "s"); + } + if (time >= 1) { + int numSeconds = (int)time % 60; + seconds = strf("%s second%s", numSeconds, numSeconds == 1 ? "" : "s"); + } + + int numMilliseconds = round(time * 1000); + milliseconds = strf("%s millisecond%s", numMilliseconds, numMilliseconds == 1 ? "" : "s"); + + return String::joinWith(", ", hours, minutes, seconds, milliseconds); +} + +String Time::printCurrentDateAndTime(String format) { + return printDateAndTime(epochTicks(), format); +} + +double Time::ticksToSeconds(int64_t ticks, int64_t tickFrequency) { + return ticks / (double)tickFrequency; +} + +int64_t Time::ticksToMilliseconds(int64_t ticks, int64_t tickFrequency) { + int64_t ticksPerMs = (tickFrequency + 500) / 1000; + return (ticks + ticksPerMs / 2) / ticksPerMs; +} + +int64_t Time::secondsToTicks(double seconds, int64_t tickFrequency) { + return round(seconds * tickFrequency); +} + +int64_t Time::millisecondsToTicks(int64_t milliseconds, int64_t tickFrequency) { + return milliseconds * ((tickFrequency + 500) / 1000); +} + +Clock::Clock(bool start) { + m_elapsedTicks = 0; + m_running = false; + if (start) + Clock::start(); +} + +Clock::Clock(Clock const& clock) { + operator=(clock); +} + +Clock& Clock::operator=(Clock const& clock) { + m_elapsedTicks = clock.m_elapsedTicks; + m_lastTicks = clock.m_lastTicks; + m_running = clock.m_running; + + return *this; +} + +void Clock::reset() { + RecursiveMutexLocker locker(m_mutex); + updateElapsed(); + m_elapsedTicks = 0; +} + +void Clock::stop() { + RecursiveMutexLocker locker(m_mutex); + m_lastTicks.reset(); + m_running = false; +} + +void Clock::start() { + RecursiveMutexLocker locker(m_mutex); + m_running = true; + updateElapsed(); +} + +bool Clock::running() const { + RecursiveMutexLocker locker(m_mutex); + return m_running; +} + +double Clock::time() const { + RecursiveMutexLocker locker(m_mutex); + updateElapsed(); + return Time::ticksToSeconds(m_elapsedTicks, Time::monotonicTickFrequency()); +} + +int64_t Clock::milliseconds() const { + RecursiveMutexLocker locker(m_mutex); + updateElapsed(); + return Time::ticksToMilliseconds(m_elapsedTicks, Time::monotonicTickFrequency()); +} + +void Clock::setTime(double time) { + RecursiveMutexLocker locker(m_mutex); + updateElapsed(); + m_elapsedTicks = Time::secondsToTicks(time, Time::monotonicTickFrequency()); +} + +void Clock::setMilliseconds(int64_t millis) { + RecursiveMutexLocker locker(m_mutex); + updateElapsed(); + m_elapsedTicks = Time::millisecondsToTicks(millis, Time::monotonicTickFrequency()); +} + +void Clock::adjustTime(double timeAdjustment) { + RecursiveMutexLocker locker(m_mutex); + setTime(time() + timeAdjustment); +} + +void Clock::adjustMilliseconds(int64_t millisAdjustment) { + RecursiveMutexLocker locker(m_mutex); + setMilliseconds(milliseconds() + millisAdjustment); +} + +void Clock::updateElapsed() const { + if (!m_running) + return; + + int64_t currentTicks = Time::monotonicTicks(); + + if (m_lastTicks) + m_elapsedTicks += (currentTicks - *m_lastTicks); + + m_lastTicks = currentTicks; +} + +Timer Timer::withTime(double timeLeft, bool start) { + Timer timer; + timer.setTime(-timeLeft); + if (start) + timer.start(); + return timer; +} + +Timer Timer::withMilliseconds(int64_t millis, bool start) { + Timer timer; + timer.setMilliseconds(-millis); + if (start) + timer.start(); + return timer; +} + +Timer::Timer() : Clock(false) { + setTime(0.0); +} + +Timer::Timer(Timer const& timer) + : Clock(timer) {} + +void Timer::restart(double timeLeft) { + Clock::setTime(-timeLeft); + Clock::start(); +} + +void Timer::restartWithMilliseconds(int64_t millisecondsLeft) { + Clock::setMilliseconds(-millisecondsLeft); + Clock::start(); +} + +double Timer::timeLeft(bool negative) const { + double timeLeft = -Clock::time(); + if (!negative) + timeLeft = max(0.0, timeLeft); + return timeLeft; +} + +int64_t Timer::millisecondsLeft(bool negative) const { + int64_t millisLeft = -Clock::milliseconds(); + if (!negative) + millisLeft = max<int64_t>(0, millisLeft); + return millisLeft; +} + +bool Timer::timeUp() const { + return Clock::time() >= 0.0; +} + +} diff --git a/source/core/StarTime.hpp b/source/core/StarTime.hpp new file mode 100644 index 0000000..e4dd274 --- /dev/null +++ b/source/core/StarTime.hpp @@ -0,0 +1,108 @@ +#ifndef STAR_TIME_HPP +#define STAR_TIME_HPP + +#include "StarThread.hpp" + +namespace Star { + +STAR_CLASS(Clock); +STAR_CLASS(Timer); + +namespace Time { + double timeSinceEpoch(); + int64_t millisecondsSinceEpoch(); + + double monotonicTime(); + int64_t monotonicMilliseconds(); + + // Pretty print a duration of time (In days, hours, minutes, seconds, and milliseconds) + String printDuration(double time); + + // Pretty print a given date and time + String printDateAndTime(int64_t epochTicks, String format = "<year>-<month>-<day> <hours>:<minutes>:<seconds>.<millis>"); + String printCurrentDateAndTime(String format = "<year>-<month>-<day> <hours>:<minutes>:<seconds>.<millis>"); + + // Ticks since unix epoch + int64_t epochTicks(); + // Epoch ticks per second, static throughout application lifetime. + int64_t epochTickFrequency(); + + // Ticks since unspecified time before program start + int64_t monotonicTicks(); + // Monotonic ticks per second, static throughout application lifetime. + int64_t monotonicTickFrequency(); + + double ticksToSeconds(int64_t ticks, int64_t tickFrequency); + int64_t ticksToMilliseconds(int64_t ticks, int64_t tickFrequency); + int64_t secondsToTicks(double seconds, int64_t tickFrequency); + int64_t millisecondsToTicks(int64_t milliseconds, int64_t tickFrequency); +} + +// Keeps track of elapsed real time since a given moment. Guaranteed +// monotonically increasing and thread safe. +class Clock { +public: + explicit Clock(bool start = true); + + Clock(Clock const& clock); + + Clock& operator=(Clock const& clock); + + // Resets clock to 0 time + void reset(); + + void stop(); + void start(); + + bool running() const; + + double time() const; + int64_t milliseconds() const; + + // Override actual elapsed time with the given time. + void setTime(double time); + void setMilliseconds(int64_t millis); + + // Warp the clock backwards or forwards + void adjustTime(double timeAdjustment); + void adjustMilliseconds(int64_t millisAdjustment); + +private: + void updateElapsed() const; + + mutable RecursiveMutex m_mutex; + mutable int64_t m_elapsedTicks; + mutable Maybe<int64_t> m_lastTicks; + bool m_running; +}; + +// An instance of Clock that counts down a given amount of time +class Timer : private Clock { +public: + static Timer withTime(double timeLeft, bool start = true); + static Timer withMilliseconds(int64_t millis, bool start = true); + + // Constructs a stopped timer whose time is up. + Timer(); + Timer(Timer const& timer); + + // Start the timer with the given time left. + void restart(double timeLeft); + void restartWithMilliseconds(int64_t millisecondsLeft); + + // Time remaining on the timer. If negative is true, will return negative + // time values after the timer is up, if false it stops at zero. + double timeLeft(bool negative = false) const; + int64_t millisecondsLeft(bool negative = false) const; + + // Is the time remaining <= 0.0? + bool timeUp() const; + + using Clock::stop; + using Clock::start; + using Clock::running; +}; + +} + +#endif diff --git a/source/core/StarTime_unix.cpp b/source/core/StarTime_unix.cpp new file mode 100644 index 0000000..50ed954 --- /dev/null +++ b/source/core/StarTime_unix.cpp @@ -0,0 +1,96 @@ +#include "StarTime.hpp" +#include "StarMathCommon.hpp" + +#include <sys/time.h> + +#ifdef STAR_SYSTEM_MACOS +#include <mach/mach_time.h> +#else +#include <time.h> +#endif + +namespace Star { + +String Time::printDateAndTime(int64_t epochTicks, String format) { + // playing fast and loose with the standard here... + time_t requestedTime = epochTicks / epochTickFrequency(); + struct tm ptm; + localtime_r(&requestedTime, &ptm); + + return format.replaceTags(StringMap<String>{ + {"year", strf("%04d", ptm.tm_year + 1900)}, + {"month", strf("%02d", ptm.tm_mon + 1)}, + {"day", strf("%02d", ptm.tm_mday)}, + {"hours", strf("%02d", ptm.tm_hour)}, + {"minutes", strf("%02d", ptm.tm_min)}, + {"seconds", strf("%02d", ptm.tm_sec)}, + {"millis", strf("%03d", (epochTicks % epochTickFrequency()) / (epochTickFrequency() / 1000))} + }); +} + +int64_t Time::epochTicks() { + timeval tv; + gettimeofday(&tv, NULL); + return (int64_t)tv.tv_sec * 1'000'000 + tv.tv_usec; +} + +int64_t Time::epochTickFrequency() { + return 1'000'000; +} + +#ifdef STAR_SYSTEM_MACOS + +struct MonotonicClock { + MonotonicClock() { + mach_timebase_info(&timebaseInfo); + }; + + int64_t ticks() const { + int64_t t = mach_absolute_time(); + return (t / 100) * timebaseInfo.numer / timebaseInfo.denom; + } + + int64_t frequency() const { + // hard coded to 100ns increments + return 10'000'000; + } + + mach_timebase_info_data_t timebaseInfo; +}; + +#else + +struct MonotonicClock { + MonotonicClock() { + timespec ts; + clock_getres(CLOCK_MONOTONIC, &ts); + starAssert(ts.tv_sec == 0); + storedFrequency = 1'000'000'000 / ts.tv_nsec; + }; + + int64_t ticks() const { + timespec ts; + clock_gettime(CLOCK_MONOTONIC, &ts); + return ts.tv_sec * storedFrequency + ts.tv_nsec * storedFrequency / 1'000'000'000; + } + + int64_t frequency() const { + return storedFrequency; + } + + int64_t storedFrequency; +}; + +#endif + +static MonotonicClock g_monotonicClock; + +int64_t Time::monotonicTicks() { + return g_monotonicClock.ticks(); +} + +int64_t Time::monotonicTickFrequency() { + return g_monotonicClock.frequency(); +} + +} diff --git a/source/core/StarTime_windows.cpp b/source/core/StarTime_windows.cpp new file mode 100644 index 0000000..2b3e797 --- /dev/null +++ b/source/core/StarTime_windows.cpp @@ -0,0 +1,68 @@ +#include "StarTime.hpp" +#include "StarLexicalCast.hpp" +#include "StarMathCommon.hpp" + +#include <ctime> +#include <windows.h> + +namespace Star { + +String Time::printDateAndTime(int64_t epochTicks, String format) { + // playing fast and loose with the standard here... + time_t requestedTime = epochTicks / epochTickFrequency(); + struct tm* ptm; + ptm = localtime(&requestedTime); + + return format.replaceTags(StringMap<String>{ + {"year", strf("%04d", ptm->tm_year + 1900)}, + {"month", strf("%02d", ptm->tm_mon + 1)}, + {"day", strf("%02d", ptm->tm_mday)}, + {"hours", strf("%02d", ptm->tm_hour)}, + {"minutes", strf("%02d", ptm->tm_min)}, + {"seconds", strf("%02d", ptm->tm_sec)}, + {"millis", strf("%03d", (epochTicks % epochTickFrequency()) / (epochTickFrequency() / 1000))}, + }); +} + +int64_t Time::epochTicks() { + FILETIME ft_now; + GetSystemTimeAsFileTime(&ft_now); + LONGLONG now = (LONGLONG)ft_now.dwLowDateTime + ((LONGLONG)(ft_now.dwHighDateTime) << 32LL); + now -= 116444736000000000LL; + return now; +} + +int64_t Time::epochTickFrequency() { + return 10000000LL; +} + +struct MonotonicClock { + MonotonicClock() { + QueryPerformanceFrequency(&freq); + }; + + int64_t ticks() const { + LARGE_INTEGER ticks; + QueryPerformanceCounter(&ticks); + return ticks.QuadPart; + } + + int64_t frequency() const { + return freq.QuadPart; + } + + LARGE_INTEGER freq; +}; + +static MonotonicClock g_monotonicClock; + +int64_t Time::monotonicTicks() { + return g_monotonicClock.ticks(); +} + +int64_t Time::monotonicTickFrequency() { + return g_monotonicClock.frequency(); +} + + +} diff --git a/source/core/StarTtlCache.hpp b/source/core/StarTtlCache.hpp new file mode 100644 index 0000000..c70741a --- /dev/null +++ b/source/core/StarTtlCache.hpp @@ -0,0 +1,203 @@ +#ifndef STAR_TTL_CACHE_HPP +#define STAR_TTL_CACHE_HPP + +#include "StarLruCache.hpp" +#include "StarTime.hpp" +#include "StarRandom.hpp" + +namespace Star { + +template <typename LruCacheType> +class TtlCacheBase { +public: + typedef typename LruCacheType::Key Key; + typedef typename LruCacheType::Value::second_type Value; + + typedef function<Value(Key const&)> ProducerFunction; + + TtlCacheBase(int64_t timeToLive = 10000, int timeSmear = 1000, size_t maxSize = NPos, bool ttlUpdateEnabled = true); + + int64_t timeToLive() const; + void setTimeToLive(int64_t timeToLive); + + int timeSmear() const; + void setTimeSmear(int timeSmear); + + // If a max size is set, this cache also acts as an LRU cache with the given + // maximum size. + size_t maxSize() const; + void setMaxSize(size_t maxSize = NPos); + + size_t currentSize() const; + + List<Key> keys() const; + List<Value> values() const; + + // If ttlUpdateEnabled is false, then the time to live for entries will not + // be updated on access. + bool ttlUpdateEnabled() const; + void setTtlUpdateEnabled(bool enabled); + + // If the value is in the cache, returns it and updates the access time, + // otherwise returns nullptr. + Value* ptr(Key const& key); + + // Put the given value into the cache. + void set(Key const& key, Value value); + // Removes the given value from the cache. If found and removed, returns + // true. + bool remove(Key const& key); + + // Remove all key / value pairs matching a filter. + void removeWhere(function<bool(Key const&, Value&)> filter); + + // If the value for the key is not found in the cache, produce it with the + // given producer. Producer should take the key as an argument and return + // the Value. + template <typename Producer> + Value& get(Key const& key, Producer producer); + + void clear(); + + // Cleanup any cached entries that are older than their time to live, if the + // refreshFilter is given, things that match the refreshFilter instead have + // their ttl refreshed rather than being removed. + void cleanup(function<bool(Key const&, Value const&)> refreshFilter = {}); + +private: + LruCacheType m_cache; + int64_t m_timeToLive; + int m_timeSmear; + bool m_ttlUpdateEnabled; +}; + +template <typename Key, typename Value, typename Compare = std::less<Key>, typename Allocator = BlockAllocator<pair<Key const, pair<int64_t, Value>>, 1024>> +using TtlCache = TtlCacheBase<LruCache<Key, pair<int64_t, Value>, Compare, Allocator>>; + +template <typename Key, typename Value, typename Hash = Star::hash<Key>, typename Equals = std::equal_to<Key>, typename Allocator = BlockAllocator<pair<Key const, pair<int64_t, Value>>, 1024>> +using HashTtlCache = TtlCacheBase<HashLruCache<Key, pair<int64_t, Value>, Hash, Equals, Allocator>>; + +template <typename LruCacheType> +TtlCacheBase<LruCacheType>::TtlCacheBase(int64_t timeToLive, int timeSmear, size_t maxSize, bool ttlUpdateEnabled) { + m_cache.setMaxSize(maxSize); + m_timeToLive = timeToLive; + m_timeSmear = timeSmear; + m_ttlUpdateEnabled = ttlUpdateEnabled; +} + +template <typename LruCacheType> +int64_t TtlCacheBase<LruCacheType>::timeToLive() const { + return m_timeToLive; +} + +template <typename LruCacheType> +void TtlCacheBase<LruCacheType>::setTimeToLive(int64_t timeToLive) { + m_timeToLive = timeToLive; +} + +template <typename LruCacheType> +int TtlCacheBase<LruCacheType>::timeSmear() const { + return m_timeSmear; +} + +template <typename LruCacheType> +void TtlCacheBase<LruCacheType>::setTimeSmear(int timeSmear) { + m_timeSmear = timeSmear; +} + +template <typename LruCacheType> +bool TtlCacheBase<LruCacheType>::ttlUpdateEnabled() const { + return m_ttlUpdateEnabled; +} + +template <typename LruCacheType> +size_t TtlCacheBase<LruCacheType>::maxSize() const { + return m_cache.maxSize(); +} + +template <typename LruCacheType> +void TtlCacheBase<LruCacheType>::setMaxSize(size_t maxSize) { + m_cache.setMaxSize(maxSize); +} + +template <typename LruCacheType> +size_t TtlCacheBase<LruCacheType>::currentSize() const { + return m_cache.currentSize(); +} + +template <typename LruCacheType> +auto TtlCacheBase<LruCacheType>::keys() const -> List<Key> { + return m_cache.keys(); +} + +template <typename LruCacheType> +auto TtlCacheBase<LruCacheType>::values() const -> List<Value> { + List<Value> values; + for (auto& p : m_cache.values()) + values.append(move(p.second)); + return values; +} + +template <typename LruCacheType> +void TtlCacheBase<LruCacheType>::setTtlUpdateEnabled(bool enabled) { + m_ttlUpdateEnabled = enabled; +} + +template <typename LruCacheType> +auto TtlCacheBase<LruCacheType>::ptr(Key const& key) -> Value * { + if (auto p = m_cache.ptr(key)) { + if (m_ttlUpdateEnabled) + p->first = Time::monotonicMilliseconds() + Random::randInt(-m_timeSmear, m_timeSmear); + return &p->second; + } + return nullptr; +} + +template <typename LruCacheType> +void TtlCacheBase<LruCacheType>::set(Key const& key, Value value) { + m_cache.set(key, make_pair(Time::monotonicMilliseconds() + Random::randInt(-m_timeSmear, m_timeSmear), value)); +} + +template <typename LruCacheType> +bool TtlCacheBase<LruCacheType>::remove(Key const& key) { + return m_cache.remove(key); +} + +template <typename LruCacheType> +void TtlCacheBase<LruCacheType>::removeWhere(function<bool(Key const&, Value&)> filter) { + m_cache.removeWhere([&filter](auto const& key, auto& value) { return filter(key, value.second); }); +} + +template <typename LruCacheType> +template <typename Producer> +auto TtlCacheBase<LruCacheType>::get(Key const& key, Producer producer) -> Value & { + auto& value = m_cache.get(key, [producer](Key const& key) { + return pair<int64_t, Value>(0, producer(key)); + }); + if (value.first == 0 || m_ttlUpdateEnabled) + value.first = Time::monotonicMilliseconds() + Random::randInt(-m_timeSmear, m_timeSmear); + return value.second; +} + +template <typename LruCacheType> +void TtlCacheBase<LruCacheType>::clear() { + m_cache.clear(); +} + +template <typename LruCacheType> +void TtlCacheBase<LruCacheType>::cleanup(function<bool(Key const&, Value const&)> refreshFilter) { + int64_t currentTime = Time::monotonicMilliseconds(); + m_cache.removeWhere([&](auto const& key, auto& value) { + if (refreshFilter && refreshFilter(key, value.second)) { + value.first = currentTime; + } else { + if (currentTime - value.first > m_timeToLive) + return true; + } + return false; + }); +} + +} + +#endif diff --git a/source/core/StarUdp.cpp b/source/core/StarUdp.cpp new file mode 100644 index 0000000..9ff7b48 --- /dev/null +++ b/source/core/StarUdp.cpp @@ -0,0 +1,84 @@ +#include "StarUdp.hpp" +#include "StarLogging.hpp" +#include "StarNetImpl.hpp" + +namespace Star { + +UdpSocket::UdpSocket(NetworkMode networkMode) : Socket(SocketType::Udp, networkMode) {} + +size_t UdpSocket::receive(HostAddressWithPort* address, char* data, size_t datasize) { + ReadLocker locker(m_mutex); + checkOpen("UdpSocket::receive"); + + int flags = 0; + int len; + struct sockaddr_storage sockAddr; + socklen_t sockAddrLen = sizeof(sockAddr); + + len = ::recvfrom(m_impl->socketDesc, data, datasize, flags, (struct sockaddr*)&sockAddr, &sockAddrLen); + + if (len < 0) { + if (!isActive()) + throw SocketClosedException("Connection closed"); + else if (netErrorInterrupt()) + len = 0; + else + throw NetworkException(strf("udp recv error: %s", netErrorString())); + } + + if (address) + setAddressFromNative(*address, m_localAddress.address().mode(), &sockAddr); + + return len; +} + +size_t UdpSocket::send(HostAddressWithPort const& address, char const* data, size_t size) { + ReadLocker locker(m_mutex); + checkOpen("UdpSocket::send"); + + struct sockaddr_storage sockAddr; + socklen_t sockAddrLen; + setNativeFromAddress(address, &sockAddr, &sockAddrLen); + + int len = ::sendto(m_impl->socketDesc, data, size, 0, (struct sockaddr*)&sockAddr, sockAddrLen); + if (len < 0) { + if (!isActive()) + throw SocketClosedException("Connection closed"); + else if (netErrorInterrupt()) + len = 0; + else + throw NetworkException(strf("udp send error: %s", netErrorString())); + } + + return len; +} + +UdpServer::UdpServer(HostAddressWithPort const& address) + : m_hostAddress(address), m_listenSocket(make_shared<UdpSocket>(m_hostAddress.address().mode())) { + m_listenSocket->setNonBlocking(true); + m_listenSocket->bind(m_hostAddress); + Logger::debug("UdpServer listening on: %s", m_hostAddress); +} + +UdpServer::~UdpServer() { + close(); +} + +size_t UdpServer::receive(HostAddressWithPort* address, char* data, size_t bufsize, unsigned timeout) { + Socket::poll({{m_listenSocket, {true, false}}}, timeout); + return m_listenSocket->receive(address, data, bufsize); +} + +size_t UdpServer::send(HostAddressWithPort const& address, char const* data, size_t len) { + return m_listenSocket->send(address, data, len); +} + +void UdpServer::close() { + m_listenSocket->close(); +} + +bool UdpServer::isListening() const { + return m_listenSocket->isActive(); +} + +} diff --git a/source/core/StarUdp.hpp b/source/core/StarUdp.hpp new file mode 100644 index 0000000..81d28a4 --- /dev/null +++ b/source/core/StarUdp.hpp @@ -0,0 +1,41 @@ +#ifndef STAR_UDP_HPP +#define STAR_UDP_HPP + +#include "StarSocket.hpp" + +namespace Star { + +STAR_CLASS(UdpSocket); +STAR_CLASS(UdpServer); + +// A Good default assumption for a maximum size of a UDP datagram without +// fragmentation +unsigned const MaxUdpData = 1460; + +class UdpSocket : public Socket { +public: + UdpSocket(NetworkMode networkMode); + + size_t receive(HostAddressWithPort* address, char* data, size_t size); + size_t send(HostAddressWithPort const& address, char const* data, size_t size); +}; + +class UdpServer { +public: + UdpServer(HostAddressWithPort const& address); + ~UdpServer(); + + void close(); + bool isListening() const; + + size_t receive(HostAddressWithPort* address, char* data, size_t size, unsigned timeout); + size_t send(HostAddressWithPort const& address, char const* data, size_t size); + +private: + HostAddressWithPort const m_hostAddress; + UdpSocketPtr m_listenSocket; +}; + +} + +#endif diff --git a/source/core/StarUnicode.cpp b/source/core/StarUnicode.cpp new file mode 100644 index 0000000..008df96 --- /dev/null +++ b/source/core/StarUnicode.cpp @@ -0,0 +1,273 @@ +#include "StarUnicode.hpp" +#include "StarEncode.hpp" + +namespace Star { + +void throwInvalidUtf8Sequence() { + throw UnicodeException("Invalid UTF-8 code unit sequence in utf8Length"); +} + +void throwMissingUtf8End() { + throw UnicodeException("UTF-8 string missing trailing code units in utf8Length"); +} + +void throwInvalidUtf32CodePoint(Utf32Type val) { + throw UnicodeException::format("Invalid UTF-32 code point %s encountered while trying to encode UTF-8", (int32_t)val); +} + +size_t utf8Length(const Utf8Type* utf8, size_t remain) { + bool stopOnNull = remain == NPos; + size_t length = 0; + + while (true) { + if (remain == 0) + break; + + if (stopOnNull && utf8[0] == 0) + break; + + if ((utf8[0] & 0x80) == 0x00) { + ++length; + ++utf8; + --remain; + continue; + } + + if (remain == 1) + throwMissingUtf8End(); + + if ((utf8[0] & 0xe0) == 0xc0 && (utf8[1] & 0xc0) == 0x80) { + if (((utf8[0] & 0x1fL) << 6) >= 0x00000080L) { + ++length; + utf8 += 2; + remain -= 2; + continue; + } else { + throwInvalidUtf8Sequence(); + } + } + + if (remain == 2) + throwMissingUtf8End(); + + if ((utf8[0] & 0xf0) == 0xe0 && (utf8[1] & 0xc0) == 0x80 && (utf8[2] & 0xc0) == 0x80) { + if ((((utf8[0] & 0x0fL) << 12) | ((utf8[1] & 0x3fL) << 6)) >= 0x00000800L) { + ++length; + utf8 += 3; + remain -= 3; + continue; + } else { + throwInvalidUtf8Sequence(); + } + } + + if (remain == 3) + throwMissingUtf8End(); + + if ((utf8[0] & 0xf8) == 0xf0 && (utf8[1] & 0xc0) == 0x80 && (utf8[2] & 0xc0) == 0x80 && (utf8[3] & 0xc0) == 0x80) { + if ((((utf8[0] & 0x07L) << 18) | ((utf8[1] & 0x3fL) << 12)) >= 0x00010000L) { + ++length; + utf8 += 4; + remain -= 4; + continue; + } else { + throwInvalidUtf8Sequence(); + } + } else { + throwInvalidUtf8Sequence(); + } + } + + return length; +} + +size_t utf8DecodeChar(const Utf8Type* utf8, Utf32Type* utf32, size_t remain) { + const Utf8Type* start = utf8; + bool stopOnNull = remain == NPos; + + while (true) { + if (remain == 0) + break; + + if (stopOnNull && utf8[0] == 0) + break; + + if ((utf8[0] & 0x80) == 0x00) { + *utf32 = utf8[0]; + return utf8 - start + 1; + } + + if (remain == 1) + throwMissingUtf8End(); + + if ((utf8[0] & 0xe0) == 0xc0 && (utf8[1] & 0xc0) == 0x80) { + *utf32 = ((utf8[0] & 0x1fL) << 6) | ((utf8[1] & 0x3fL) << 0); + if (*utf32 >= 0x00000080L) + return utf8 - start + 2; + else + throwInvalidUtf8Sequence(); + } + + if (remain == 2) + throwMissingUtf8End(); + + if ((utf8[0] & 0xf0) == 0xe0 && (utf8[1] & 0xc0) == 0x80 && (utf8[2] & 0xc0) == 0x80) { + *utf32 = ((utf8[0] & 0x0fL) << 12) | ((utf8[1] & 0x3fL) << 6) | ((utf8[2] & 0x3fL) << 0); + if (*utf32 >= 0x00000800L) + return utf8 - start + 3; + else + throwInvalidUtf8Sequence(); + } + + if (remain == 3) + throwMissingUtf8End(); + + if ((utf8[0] & 0xf8) == 0xf0 && (utf8[1] & 0xc0) == 0x80 && (utf8[2] & 0xc0) == 0x80 && (utf8[3] & 0xc0) == 0x80) { + *utf32 = + ((utf8[0] & 0x07L) << 18) | ((utf8[1] & 0x3fL) << 12) | ((utf8[2] & 0x3fL) << 6) | ((utf8[3] & 0x3fL) << 0); + if (*utf32 >= 0x00010000L) + return utf8 - start + 4; + else + throwInvalidUtf8Sequence(); + } else { + throwInvalidUtf8Sequence(); + } + } + + return utf8 - start; +} + +size_t utf8EncodeChar(Utf8Type* utf8, Utf32Type utf32, size_t len) { + if (utf32 > 0x10FFFFu) + throwInvalidUtf32CodePoint(utf32); + + if (utf32 <= 0x0000007fL) { + if (len < 1) + return 0; + + utf8[0] = utf32; + return 1; + } else if (utf32 <= 0x000007ffL) { + if (len < 2) + return 0; + + utf8[0] = 0xc0 | ((utf32 >> 6) & 0x1f); + utf8[1] = 0x80 | ((utf32 >> 0) & 0x3f); + + return 2; + } else if (utf32 <= 0x0000ffffL) { + if (len < 3) + return 0; + + utf8[0] = 0xe0 | ((utf32 >> 12) & 0x0f); + utf8[1] = 0x80 | ((utf32 >> 6) & 0x3f); + utf8[2] = 0x80 | ((utf32 >> 0) & 0x3f); + + return 3; + } else { + if (len < 4) + return 0; + + utf8[0] = 0xf0 | ((utf32 >> 18) & 0x07); + utf8[1] = 0x80 | ((utf32 >> 12) & 0x3f); + utf8[2] = 0x80 | ((utf32 >> 6) & 0x3f); + utf8[3] = 0x80 | ((utf32 >> 0) & 0x3f); + + return 4; + } +} + +static const char32_t MIN_LEAD = 0xd800; +static const char32_t MAX_LEAD = 0xdbff; +static const char32_t MIN_TRAIL = 0xdc00; +static const char32_t MAX_TRAIL = 0xdfff; +static const char32_t SURR_MASK = 0x3ff; +static const char32_t MIN_PAIR = 0x10000; +static const char32_t MAX_CODEPOINT = 0x10ffff; + +Utf32Type hexStringToUtf32(std::string const& codepoint, Maybe<Utf32Type> previousCodepoint) { + bool continuation = false; + if (previousCodepoint && isUtf16LeadSurrogate(*previousCodepoint)) { + continuation = true; + } + + auto hexBytes = hexDecode(codepoint); + if (hexBytes.size() < sizeof(Utf32Type)) { + ByteArray newHexBytes{(size_t)(sizeof(Utf32Type) - hexBytes.size()), (char)'\0'}; + newHexBytes.append(hexBytes); + hexBytes = newHexBytes; + } + + if (hexBytes.size() > sizeof(Utf32Type)) + throw UnicodeException("Codepoint size is too big in parseUnicodeCodepoint"); + + auto res = fromBigEndian(*(Utf32Type*)hexBytes.ptr()); + + if (continuation) { + res = utf32FromUtf16SurrogatePair(*previousCodepoint, res); + } + + return res; +} + +std::string hexStringFromUtf32(Utf32Type character) { + if (character > MAX_CODEPOINT) + throw UnicodeException("Codepoint too big in hexStringFromUtf32"); + Utf32Type lead; + Maybe<Utf32Type> trail; + tie(lead, trail) = utf32ToUtf16SurrogatePair(character); + + char16_t leadOut = toBigEndian((char16_t)lead); + auto leadHex = hexEncode(reinterpret_cast<char*>(&leadOut), sizeof(leadOut)).takeUtf8(); + + starAssert(leadHex.size() == 4); + + if (!trail) + return leadHex; + + char16_t trailOut = toBigEndian((char16_t)*trail); + auto trailHex = hexEncode(reinterpret_cast<char*>(&trailOut), sizeof(trailOut)); + + starAssert(trailHex.size() == 4); + + return (leadHex + trailHex).takeUtf8(); +} + +bool isUtf16LeadSurrogate(Utf32Type codepoint) { + return codepoint >= MIN_LEAD && codepoint <= MAX_LEAD; +} + +bool isUtf16TrailSurrogate(Utf32Type codepoint) { + return codepoint >= MIN_TRAIL && codepoint <= MAX_TRAIL; +} + +Utf32Type utf32FromUtf16SurrogatePair(Utf32Type lead, Utf32Type trail) { + if (!isUtf16LeadSurrogate(lead)) + throw UnicodeException("Invalid lead surrogate passed to utf32FromUtf16SurrogatePair"); + if (!isUtf16TrailSurrogate(trail)) + throw UnicodeException("Invalid trail surrogate passed to utf32FromUtf16SurrogatePair"); + + lead -= MIN_LEAD; + trail -= MIN_TRAIL; + + Utf32Type codepoint = (lead << 10) + trail + MIN_PAIR; + + return codepoint; +} + +pair<Utf32Type, Maybe<Utf32Type>> utf32ToUtf16SurrogatePair(Utf32Type codepoint) { + if (codepoint >= MIN_PAIR) { + codepoint -= MIN_PAIR; + Utf32Type lead = (codepoint >> 10) + MIN_LEAD; + Utf32Type trail = (codepoint & SURR_MASK) + MIN_TRAIL; + + if (!isUtf16LeadSurrogate(lead)) + throw UnicodeException("Invalid codepoint passed to utf32ToUtf16SurrogatePair"); + + return {lead, trail}; + } + + return {codepoint, {}}; +} + +} diff --git a/source/core/StarUnicode.hpp b/source/core/StarUnicode.hpp new file mode 100644 index 0000000..845259f --- /dev/null +++ b/source/core/StarUnicode.hpp @@ -0,0 +1,229 @@ +#ifndef STAR_UTF8_HPP +#define STAR_UTF8_HPP + +#include "StarByteArray.hpp" +#include "StarMaybe.hpp" + +namespace Star { + +STAR_EXCEPTION(UnicodeException, StarException); + +typedef char Utf8Type; +typedef char32_t Utf32Type; + +#define STAR_UTF32_REPLACEMENT_CHAR 0x000000b7L + +void throwInvalidUtf8Sequence(); +void throwMissingUtf8End(); +void throwInvalidUtf32CodePoint(Utf32Type val); + +// If passed NPos as a size, assumes modified UTF-8 and stops on NULL byte. +// Otherwise, ignores NULL. +size_t utf8Length(Utf8Type const* utf8, size_t size = NPos); +// Encode up to six utf8 bytes into a utf32 character. If passed NPos as len, +// assumes modified UTF-8 and stops on NULL, otherwise ignores. +size_t utf8DecodeChar(Utf8Type const* utf8, Utf32Type* utf32, size_t len = NPos); +// Encode single utf32 char into up to 6 utf8 characters. +size_t utf8EncodeChar(Utf8Type* utf8, Utf32Type utf32, size_t len = 6); + +Utf32Type hexStringToUtf32(std::string const& codepoint, Maybe<Utf32Type> previousCodepoint = {}); +std::string hexStringFromUtf32(Utf32Type character); + +bool isUtf16LeadSurrogate(Utf32Type codepoint); +bool isUtf16TrailSurrogate(Utf32Type codepoint); + +Utf32Type utf32FromUtf16SurrogatePair(Utf32Type lead, Utf32Type trail); +pair<Utf32Type, Maybe<Utf32Type>> utf32ToUtf16SurrogatePair(Utf32Type codepoint); + +// Bidirectional iterator that can make utf8 appear as utf32 +template <class BaseIterator, class U32Type = Utf32Type> +class U8ToU32Iterator { +public: + typedef ptrdiff_t difference_type; + typedef U32Type value_type; + typedef U32Type* pointer; + typedef U32Type& reference; + typedef std::bidirectional_iterator_tag iterator_category; + + U8ToU32Iterator() : m_position(), m_value(pending_read) {} + + U8ToU32Iterator(BaseIterator b) : m_position(b), m_value(pending_read) {} + + U32Type const& operator*() const { + if (m_value == pending_read) + extract_current(); + return m_value; + } + + U8ToU32Iterator const& operator++() { + increment(); + return *this; + } + + U8ToU32Iterator operator++(int) { + U8ToU32Iterator clone(*this); + increment(); + return clone; + } + + U8ToU32Iterator const& operator--() { + decrement(); + return *this; + } + + U8ToU32Iterator operator--(int) { + U8ToU32Iterator clone(*this); + decrement(); + return clone; + } + + bool operator==(U8ToU32Iterator const& that) const { + return equal(that); + } + + bool operator!=(U8ToU32Iterator const& that) const { + return !equal(that); + } + +private: + // special values for pending iterator reads: + static U32Type const pending_read = 0xffffffffu; + + static void invalid_sequence() { + throwInvalidUtf8Sequence(); + } + + static unsigned utf8_byte_count(Utf8Type c) { + // if the most significant bit with a zero in it is in position + // 8-N then there are N bytes in this UTF-8 sequence: + uint8_t mask = 0x80u; + unsigned result = 0; + while (c & mask) { + ++result; + mask >>= 1; + } + return (result == 0) ? 1 : ((result > 4) ? 4 : result); + } + + static unsigned utf8_trailing_byte_count(Utf8Type c) { + return utf8_byte_count(c) - 1; + } + + void increment() { + // skip high surrogate first if there is one: + unsigned c = utf8_byte_count(*m_position); + std::advance(m_position, c); + m_value = pending_read; + } + + void decrement() { + // Keep backtracking until we don't have a trailing character: + unsigned count = 0; + while (((uint8_t) * --m_position & 0xC0u) == 0x80u) + ++count; + // now check that the sequence was valid: + if (count != utf8_trailing_byte_count(*m_position)) + invalid_sequence(); + m_value = pending_read; + } + + bool equal(const U8ToU32Iterator& that) const { + return m_position == that.m_position; + } + + void extract_current() const { + m_value = static_cast<Utf8Type>(*m_position); + // we must not have a continuation character: + if (((uint8_t)m_value & 0xC0u) == 0x80u) + invalid_sequence(); + // see how many extra byts we have: + unsigned extra = utf8_trailing_byte_count(*m_position); + // extract the extra bits, 6 from each extra byte: + BaseIterator next(m_position); + for (unsigned c = 0; c < extra; ++c) { + ++next; + m_value <<= 6; + auto entry = static_cast<uint8_t>(*next); + if ((c > 0) && ((entry & 0xC0u) != 0x80u)) + invalid_sequence(); + m_value += entry & 0x3Fu; + } + // we now need to remove a few of the leftmost bits, but how many depends + // upon how many extra bytes we've extracted: + static const Utf32Type masks[4] = { + 0x7Fu, 0x7FFu, 0xFFFFu, 0x1FFFFFu, + }; + m_value &= masks[extra]; + // check the result: + if ((uint32_t)m_value > (uint32_t)0x10FFFFu) + invalid_sequence(); + } + + BaseIterator m_position; + mutable U32Type m_value; +}; + +// Output iterator +template <class BaseIterator, class U32Type = Utf32Type> +class Utf8OutputIterator { +public: + typedef void difference_type; + typedef void value_type; + typedef U32Type* pointer; + typedef U32Type& reference; + + Utf8OutputIterator(const BaseIterator& b) : m_position(b) {} + Utf8OutputIterator(const Utf8OutputIterator& that) : m_position(that.m_position) {} + Utf8OutputIterator& operator=(const Utf8OutputIterator& that) { + m_position = that.m_position; + return *this; + } + + const Utf8OutputIterator& operator*() const { + return *this; + } + + void operator=(U32Type val) const { + push(val); + } + + Utf8OutputIterator& operator++() { + return *this; + } + + Utf8OutputIterator& operator++(int) { + return *this; + } + +private: + static void invalid_utf32_code_point(U32Type val) { + throwInvalidUtf32CodePoint(val); + } + + void push(U32Type c) const { + if (c > 0x10FFFFu) + invalid_utf32_code_point(c); + + if ((uint32_t)c < 0x80u) { + *m_position++ = static_cast<Utf8Type>((uint32_t)c); + } else if ((uint32_t)c < 0x800u) { + *m_position++ = static_cast<Utf8Type>(0xC0u + ((uint32_t)c >> 6)); + *m_position++ = static_cast<Utf8Type>(0x80u + ((uint32_t)c & 0x3Fu)); + } else if ((uint32_t)c < 0x10000u) { + *m_position++ = static_cast<Utf8Type>(0xE0u + ((uint32_t)c >> 12)); + *m_position++ = static_cast<Utf8Type>(0x80u + (((uint32_t)c >> 6) & 0x3Fu)); + *m_position++ = static_cast<Utf8Type>(0x80u + ((uint32_t)c & 0x3Fu)); + } else { + *m_position++ = static_cast<Utf8Type>(0xF0u + ((uint32_t)c >> 18)); + *m_position++ = static_cast<Utf8Type>(0x80u + (((uint32_t)c >> 12) & 0x3Fu)); + *m_position++ = static_cast<Utf8Type>(0x80u + (((uint32_t)c >> 6) & 0x3Fu)); + *m_position++ = static_cast<Utf8Type>(0x80u + ((uint32_t)c & 0x3Fu)); + } + } + + mutable BaseIterator m_position; +}; + +} + +#endif diff --git a/source/core/StarUuid.cpp b/source/core/StarUuid.cpp new file mode 100644 index 0000000..11121b6 --- /dev/null +++ b/source/core/StarUuid.cpp @@ -0,0 +1,72 @@ +#include "StarUuid.hpp" +#include "StarRandom.hpp" +#include "StarFormat.hpp" +#include "StarEncode.hpp" + +namespace Star { + +Uuid::Uuid() : Uuid(Random::randBytes(UuidSize)) {} + +Uuid::Uuid(ByteArray const& bytes) { + if (bytes.size() != UuidSize) + throw UuidException(strf("Size mismatch in reading Uuid from ByteArray: %s vs %s", bytes.size(), UuidSize)); + + bytes.copyTo(m_data.ptr(), UuidSize); +} + +Uuid::Uuid(String const& hex) : Uuid(hexDecode(hex)) {} + +char const* Uuid::ptr() const { + return m_data.ptr(); +} + +ByteArray Uuid::bytes() const { + return ByteArray(m_data.ptr(), UuidSize); +} + +String Uuid::hex() const { + return hexEncode(m_data.ptr(), UuidSize); +} + +bool Uuid::operator==(Uuid const& u) const { + return m_data == u.m_data; +} + +bool Uuid::operator!=(Uuid const& u) const { + return m_data != u.m_data; +} + +bool Uuid::operator<(Uuid const& u) const { + return m_data < u.m_data; +} + +bool Uuid::operator<=(Uuid const& u) const { + return m_data <= u.m_data; +} + +bool Uuid::operator>(Uuid const& u) const { + return m_data > u.m_data; +} + +bool Uuid::operator>=(Uuid const& u) const { + return m_data >= u.m_data; +} + +size_t hash<Uuid>::operator()(Uuid const& u) const { + size_t hashval = 0; + for (size_t i = 0; i < UuidSize; ++i) + hashCombine(hashval, u.ptr()[i]); + return hashval; +} + +DataStream& operator>>(DataStream& ds, Uuid& uuid) { + uuid = Uuid(ds.readBytes(UuidSize)); + return ds; +} + +DataStream& operator<<(DataStream& ds, Uuid const& uuid) { + ds.writeData(uuid.ptr(), UuidSize); + return ds; +} + +} diff --git a/source/core/StarUuid.hpp b/source/core/StarUuid.hpp new file mode 100644 index 0000000..3b540fc --- /dev/null +++ b/source/core/StarUuid.hpp @@ -0,0 +1,44 @@ +#ifndef STAR_UUID_HPP +#define STAR_UUID_HPP + +#include "StarArray.hpp" +#include "StarDataStream.hpp" + +namespace Star { + +STAR_EXCEPTION(UuidException, StarException); + +size_t const UuidSize = 16; + +class Uuid { +public: + Uuid(); + explicit Uuid(ByteArray const& bytes); + explicit Uuid(String const& hex); + + char const* ptr() const; + ByteArray bytes() const; + String hex() const; + + bool operator==(Uuid const& u) const; + bool operator!=(Uuid const& u) const; + bool operator<(Uuid const& u) const; + bool operator<=(Uuid const& u) const; + bool operator>(Uuid const& u) const; + bool operator>=(Uuid const& u) const; + +private: + Array<char, UuidSize> m_data; +}; + +template <> +struct hash<Uuid> { + size_t operator()(Uuid const& u) const; +}; + +DataStream& operator>>(DataStream& ds, Uuid& uuid); +DataStream& operator<<(DataStream& ds, Uuid const& uuid); + +} + +#endif diff --git a/source/core/StarVariant.hpp b/source/core/StarVariant.hpp new file mode 100644 index 0000000..fe45a6d --- /dev/null +++ b/source/core/StarVariant.hpp @@ -0,0 +1,927 @@ +#ifndef STAR_VARIANT_HPP +#define STAR_VARIANT_HPP + +#include <type_traits> +#include <utility> + +#include "StarAlgorithm.hpp" +#include "StarMaybe.hpp" + +namespace Star { + +STAR_EXCEPTION(BadVariantCast, StarException); +STAR_EXCEPTION(BadVariantType, StarException); + +typedef uint8_t VariantTypeIndex; +VariantTypeIndex const InvalidVariantType = 255; + +namespace detail { + template <typename T, typename... Args> + struct HasType; + + template <typename T> + struct HasType<T> : std::false_type {}; + + template <typename T, typename Head, typename... Args> + struct HasType<T, Head, Args...> { + static constexpr bool value = std::is_same<T, Head>::value || HasType<T, Args...>::value; + }; + + template <typename... Args> + struct IsNothrowMoveConstructible; + + template <> + struct IsNothrowMoveConstructible<> : std::true_type {}; + + template <typename Head, typename... Args> + struct IsNothrowMoveConstructible<Head, Args...> { + static constexpr bool value = std::is_nothrow_move_constructible<Head>::value && IsNothrowMoveConstructible<Args...>::value; + }; + + template <typename... Args> + struct IsNothrowMoveAssignable; + + template <> + struct IsNothrowMoveAssignable<> : std::true_type {}; + + template <typename Head, typename... Args> + struct IsNothrowMoveAssignable<Head, Args...> { + static constexpr bool value = std::is_nothrow_move_assignable<Head>::value && IsNothrowMoveAssignable<Args...>::value; + }; +} + +// Stack based variant type container that can be inhabited by one of a limited +// number of types. +template <typename FirstType, typename... RestTypes> +class Variant { +public: + template <typename T> + using ValidateType = typename std::enable_if<detail::HasType<T, FirstType, RestTypes...>::value, void>::type; + + template <typename T, typename = ValidateType<T>> + static constexpr VariantTypeIndex typeIndexOf(); + + // If the first type has a default constructor, constructs an Variant which + // contains a default constructed value of that type. + Variant(); + + template <typename T, typename = ValidateType<T>> + Variant(T const& x); + template <typename T, typename = ValidateType<T>> + Variant(T&& x); + + Variant(Variant const& x); + Variant(Variant&& x) noexcept(detail::IsNothrowMoveConstructible<FirstType, RestTypes...>::value); + + ~Variant(); + + // Implementations of operator= may invalidate the Variant if the copy or + // move constructor of the assigned value throws. + Variant& operator=(Variant const& x); + Variant& operator=(Variant&& x) noexcept(detail::IsNothrowMoveAssignable<FirstType, RestTypes...>::value); + template <typename T, typename = ValidateType<T>> + Variant& operator=(T const& x); + template <typename T, typename = ValidateType<T>> + Variant& operator=(T&& x); + + // Returns true if this Variant contains the given type. + template <typename T, typename = ValidateType<T>> + bool is() const; + + // get throws BadVariantCast on bad casts + + template <typename T, typename = ValidateType<T>> + T const& get() const; + + template <typename T, typename = ValidateType<T>> + T& get(); + + template <typename T, typename = ValidateType<T>> + Maybe<T> maybe() const; + + // ptr() does not throw if this Variant does not hold the given type, instead + // simply returns nullptr. + + template <typename T, typename = ValidateType<T>> + T const* ptr() const; + + template <typename T, typename = ValidateType<T>> + T* ptr(); + + // Calls the given function with the type currently being held, and returns + // the value returned by that function. Will throw if this Variant has been + // invalidated. + template <typename Function> + decltype(auto) call(Function&& function); + template <typename Function> + decltype(auto) call(Function&& function) const; + + // Returns an index for the held type, which can be passed into makeType to + // make this Variant hold a specific type. Returns InvalidVariantType if + // invalidated. + VariantTypeIndex typeIndex() const; + + // Make this Variant hold a new default constructed type of the given type + // index. Can only be used if every alternative type has a default + // constructor. Throws if given an out of range type index or + // InvalidVariantType. + void makeType(VariantTypeIndex typeIndex); + + // True if this Variant has been invalidated. If the copy or move + // constructor of a type throws an exception during assignment, there is no + // *good* way to ensure that the Variant has a valid type, so it may become + // invalidated. It is not possible to directly construct an invalidated + // Variant. + bool invalid() const; + + // Requires that every type included in this Variant has operator== + bool operator==(Variant const& x) const; + bool operator!=(Variant const& x) const; + + // Requires that every type included in this Variant has operator< + bool operator<(Variant const& x) const; + + template <typename T, typename = ValidateType<T>> + bool operator==(T const& x) const; + template <typename T, typename = ValidateType<T>> + bool operator!=(T const& x) const; + template <typename T, typename = ValidateType<T>> + bool operator<(T const& x) const; + +private: + template <typename MatchType, VariantTypeIndex Index, typename... Rest> + struct LookupTypeIndex; + + template <typename MatchType, VariantTypeIndex Index> + struct LookupTypeIndex<MatchType, Index> { + static VariantTypeIndex const value = InvalidVariantType; + }; + + template <typename MatchType, VariantTypeIndex Index, typename Head, typename... Rest> + struct LookupTypeIndex<MatchType, Index, Head, Rest...> { + static VariantTypeIndex const value = std::is_same<MatchType, Head>::value ? Index : LookupTypeIndex<MatchType, Index + 1, Rest...>::value; + }; + + template <typename MatchType> + struct TypeIndex { + static VariantTypeIndex const value = LookupTypeIndex<MatchType, 0, FirstType, RestTypes...>::value; + }; + + void destruct(); + + template <typename T> + void assign(T&& x); + + template <typename Function, typename T> + decltype(auto) doCall(Function&& function); + template <typename Function, typename T1, typename T2, typename... TL> + decltype(auto) doCall(Function&& function); + + template <typename Function, typename T> + decltype(auto) doCall(Function&& function) const; + template <typename Function, typename T1, typename T2, typename... TL> + decltype(auto) doCall(Function&& function) const; + + template <typename First> + void doMakeType(VariantTypeIndex); + template <typename First, typename Second, typename... Rest> + void doMakeType(VariantTypeIndex typeIndex); + + typename std::aligned_union<0, FirstType, RestTypes...>::type m_buffer; + VariantTypeIndex m_typeIndex = InvalidVariantType; +}; + +// A version of Variant that has always has a default "empty" state, useful +// when there is no good default type for a Variant but it needs to be default +// constructed, and is slightly more convenient than Maybe<Variant<Types...>>. +template <typename... Types> +class MVariant { +public: + template <typename T> + using ValidateType = typename std::enable_if<detail::HasType<T, Types...>::value, void>::type; + + template <typename T, typename = ValidateType<T>> + static constexpr VariantTypeIndex typeIndexOf(); + + MVariant(); + MVariant(MVariant const& x); + MVariant(MVariant&& x); + + template <typename T, typename = ValidateType<T>> + MVariant(T const& x); + template <typename T, typename = ValidateType<T>> + MVariant(T&& x); + + MVariant(Variant<Types...> const& x); + MVariant(Variant<Types...>&& x); + + ~MVariant(); + + // MVariant::operator= will never invalidate the MVariant, instead it will + // just become empty. + MVariant& operator=(MVariant const& x); + MVariant& operator=(MVariant&& x); + + template <typename T, typename = ValidateType<T>> + MVariant& operator=(T const& x); + template <typename T, typename = ValidateType<T>> + MVariant& operator=(T&& x); + + MVariant& operator=(Variant<Types...> const& x); + MVariant& operator=(Variant<Types...>&& x); + + // Requires that every type included in this MVariant has operator== + bool operator==(MVariant const& x) const; + bool operator!=(MVariant const& x) const; + + // Requires that every type included in this MVariant has operator< + bool operator<(MVariant const& x) const; + + template <typename T, typename = ValidateType<T>> + bool operator==(T const& x) const; + template <typename T, typename = ValidateType<T>> + bool operator!=(T const& x) const; + template <typename T, typename = ValidateType<T>> + bool operator<(T const& x) const; + + // get throws BadVariantCast on bad casts + + template <typename T, typename = ValidateType<T>> + T const& get() const; + + template <typename T, typename = ValidateType<T>> + T& get(); + + // maybe() and ptr() do not throw if this MVariant does not hold the given + // type, instead simply returns Nothing / nullptr. + + template <typename T, typename = ValidateType<T>> + Maybe<T> maybe() const; + + template <typename T, typename = ValidateType<T>> + T const* ptr() const; + + template <typename T, typename = ValidateType<T>> + T* ptr(); + + template <typename T, typename = ValidateType<T>> + bool is() const; + + // Takes the given value out and leaves this empty + template <typename T, typename = ValidateType<T>> + T take(); + + // Returns a Variant of all the allowed types if non-empty, throws + // BadVariantCast if empty. + Variant<Types...> value() const; + + // Moves the contents of this MVariant into the given Variant if non-empty, + // throws BadVariantCast if empty. + Variant<Types...> takeValue(); + + bool empty() const; + void reset(); + + // Equivalent to !empty() + explicit operator bool() const; + + // If this MVariant holds a type, calls the given function with the type + // being held. If nothing is currently held, the function is not called. + template <typename Function> + void call(Function&& function); + + template <typename Function> + void call(Function&& function) const; + + // Returns an index for the held type, which can be passed into makeType to + // make this MVariant hold a specific type. Types are always indexed in the + // order they are specified starting from 1. A type index of 0 indicates an + // empty MVariant. + VariantTypeIndex typeIndex() const; + + // Make this MVariant hold a new default constructed type of the given type + // index. Can only be used if every alternative type has a default + // constructor. + void makeType(VariantTypeIndex typeIndex); + +private: + struct MVariantEmpty { + bool operator==(MVariantEmpty const& rhs) const; + bool operator<(MVariantEmpty const& rhs) const; + }; + + template <typename Function> + struct RefCaller { + Function&& function; + + RefCaller(Function&& function); + + void operator()(MVariantEmpty& empty); + + template <typename T> + void operator()(T& t); + }; + + template <typename Function> + struct ConstRefCaller { + Function&& function; + + ConstRefCaller(Function&& function); + + void operator()(MVariantEmpty const& empty); + + template <typename T> + void operator()(T const& t); + }; + + Variant<MVariantEmpty, Types...> m_variant; +}; + +template <typename FirstType, typename... RestTypes> +template <typename T, typename> +constexpr VariantTypeIndex Variant<FirstType, RestTypes...>::typeIndexOf() { + return TypeIndex<T>::value; +} + +template <typename FirstType, typename... RestTypes> +Variant<FirstType, RestTypes...>::Variant() + : Variant(FirstType()) {} + +template <typename FirstType, typename... RestTypes> +template <typename T, typename> +Variant<FirstType, RestTypes...>::Variant(T const& x) { + assign(x); +} + +template <typename FirstType, typename... RestTypes> +template <typename T, typename> +Variant<FirstType, RestTypes...>::Variant(T&& x) { + assign(forward<T>(x)); +} + +template <typename FirstType, typename... RestTypes> +Variant<FirstType, RestTypes...>::Variant(Variant const& x) { + x.call([this](auto const& t) { + assign(t); + }); +} + +template <typename FirstType, typename... RestTypes> +Variant<FirstType, RestTypes...>::Variant(Variant&& x) + noexcept(detail::IsNothrowMoveConstructible<FirstType, RestTypes...>::value) { + x.call([this](auto& t) { + assign(move(t)); + }); +} + +template <typename FirstType, typename... RestTypes> +Variant<FirstType, RestTypes...>::~Variant() { + destruct(); +} + +template <typename FirstType, typename... RestTypes> +Variant<FirstType, RestTypes...>& Variant<FirstType, RestTypes...>::operator=(Variant const& x) { + if (&x == this) + return *this; + + x.call([this](auto const& t) { + assign(t); + }); + + return *this; +} + +template <typename FirstType, typename... RestTypes> +Variant<FirstType, RestTypes...>& Variant<FirstType, RestTypes...>::operator=(Variant&& x) + noexcept(detail::IsNothrowMoveAssignable<FirstType, RestTypes...>::value) { + if (&x == this) + return *this; + + x.call([this](auto& t) { + assign(move(t)); + }); + + return *this; +} + +template <typename FirstType, typename... RestTypes> +template <typename T, typename> +Variant<FirstType, RestTypes...>& Variant<FirstType, RestTypes...>::operator=(T const& x) { + assign(x); + return *this; +} + +template <typename FirstType, typename... RestTypes> +template <typename T, typename> +Variant<FirstType, RestTypes...>& Variant<FirstType, RestTypes...>::operator=(T&& x) { + assign(forward<T>(x)); + return *this; +} + +template <typename FirstType, typename... RestTypes> +template <typename T, typename> +T const& Variant<FirstType, RestTypes...>::get() const { + if (!is<T>()) + throw BadVariantCast(); + return *(T*)(&m_buffer); +} + +template <typename FirstType, typename... RestTypes> +template <typename T, typename> +T& Variant<FirstType, RestTypes...>::get() { + if (!is<T>()) + throw BadVariantCast(); + return *(T*)(&m_buffer); +} + +template <typename FirstType, typename... RestTypes> +template <typename T, typename> +Maybe<T> Variant<FirstType, RestTypes...>::maybe() const { + if (!is<T>()) + return {}; + return *(T*)(&m_buffer); +} + +template <typename FirstType, typename... RestTypes> +template <typename T, typename> +T const* Variant<FirstType, RestTypes...>::ptr() const { + if (!is<T>()) + return nullptr; + return (T*)(&m_buffer); +} + +template <typename FirstType, typename... RestTypes> +template <typename T, typename> +T* Variant<FirstType, RestTypes...>::ptr() { + if (!is<T>()) + return nullptr; + return (T*)(&m_buffer); +} + +template <typename FirstType, typename... RestTypes> +template <typename T, typename> +bool Variant<FirstType, RestTypes...>::is() const { + return m_typeIndex == TypeIndex<T>::value; +} + +template <typename FirstType, typename... RestTypes> +template <typename Function> +decltype(auto) Variant<FirstType, RestTypes...>::call(Function&& function) { + return doCall<Function, FirstType, RestTypes...>(forward<Function>(function)); +} + +template <typename FirstType, typename... RestTypes> +template <typename Function> +decltype(auto) Variant<FirstType, RestTypes...>::call(Function&& function) const { + return doCall<Function, FirstType, RestTypes...>(forward<Function>(function)); +} + +template <typename FirstType, typename... RestTypes> +VariantTypeIndex Variant<FirstType, RestTypes...>::typeIndex() const { + return m_typeIndex; +} + +template <typename FirstType, typename... RestTypes> +void Variant<FirstType, RestTypes...>::makeType(VariantTypeIndex typeIndex) { + return doMakeType<FirstType, RestTypes...>(typeIndex); +} + +template <typename FirstType, typename... RestTypes> +bool Variant<FirstType, RestTypes...>::invalid() const { + return m_typeIndex == InvalidVariantType; +} + +template <typename FirstType, typename... RestTypes> +bool Variant<FirstType, RestTypes...>::operator==(Variant const& x) const { + if (this == &x) { + return true; + } else if (typeIndex() != x.typeIndex()) { + return false; + } else { + return call([&x](auto const& t) { + typedef typename std::decay<decltype(t)>::type T; + return t == x.template get<T>(); + }); + } +} + +template <typename FirstType, typename... RestTypes> +bool Variant<FirstType, RestTypes...>::operator!=(Variant const& x) const { + return !operator==(x); +} + +template <typename FirstType, typename... RestTypes> +bool Variant<FirstType, RestTypes...>::operator<(Variant const& x) const { + if (this == &x) { + return false; + } else { + auto sti = typeIndex(); + auto xti = x.typeIndex(); + if (sti != xti) { + return sti < xti; + } else { + return call([&x](auto const& t) { + typedef typename std::decay<decltype(t)>::type T; + return t < x.template get<T>(); + }); + } + } +} + +template <typename FirstType, typename... RestTypes> +template <typename T, typename> +bool Variant<FirstType, RestTypes...>::operator==(T const& x) const { + if (auto p = ptr<T>()) + return *p == x; + return false; +} + +template <typename FirstType, typename... RestTypes> +template <typename T, typename> +bool Variant<FirstType, RestTypes...>::operator!=(T const& x) const { + return !operator==(x); +} + +template <typename FirstType, typename... RestTypes> +template <typename T, typename> +bool Variant<FirstType, RestTypes...>::operator<(T const& x) const { + if (auto p = ptr<T>()) + return *p == x; + return m_typeIndex < TypeIndex<T>::value; +} + +template <typename FirstType, typename... RestTypes> +void Variant<FirstType, RestTypes...>::destruct() { + if (m_typeIndex != InvalidVariantType) { + try { + call([](auto& t) { + typedef typename std::decay<decltype(t)>::type T; + t.~T(); + }); + m_typeIndex = InvalidVariantType; + } catch (...) { + m_typeIndex = InvalidVariantType; + throw; + } + } +} + +template <typename FirstType, typename... RestTypes> +template <typename T> +void Variant<FirstType, RestTypes...>::assign(T&& x) { + typedef typename std::decay<T>::type AssignType; + if (auto p = ptr<AssignType>()) { + *p = forward<T>(x); + } else { + destruct(); + new (&m_buffer) AssignType(forward<T>(x)); + m_typeIndex = TypeIndex<AssignType>::value; + } +} + +template <typename FirstType, typename... RestTypes> +template <typename Function, typename T> +decltype(auto) Variant<FirstType, RestTypes...>::doCall(Function&& function) { + if (T* p = ptr<T>()) + return function(*p); + else + throw BadVariantType(); +} + +template <typename FirstType, typename... RestTypes> +template <typename Function, typename T1, typename T2, typename... TL> +decltype(auto) Variant<FirstType, RestTypes...>::doCall(Function&& function) { + if (T1* p = ptr<T1>()) + return function(*p); + else + return doCall<Function, T2, TL...>(forward<Function>(function)); +} + +template <typename FirstType, typename... RestTypes> +template <typename Function, typename T> +decltype(auto) Variant<FirstType, RestTypes...>::doCall(Function&& function) const { + if (T const* p = ptr<T>()) + return function(*p); + else + throw BadVariantType(); +} + +template <typename FirstType, typename... RestTypes> +template <typename Function, typename T1, typename T2, typename... TL> +decltype(auto) Variant<FirstType, RestTypes...>::doCall(Function&& function) const { + if (T1 const* p = ptr<T1>()) + return function(*p); + else + return doCall<Function, T2, TL...>(forward<Function>(function)); +} + +template <typename FirstType, typename... RestTypes> +template <typename First> +void Variant<FirstType, RestTypes...>::doMakeType(VariantTypeIndex typeIndex) { + if (typeIndex == 0) + *this = First(); + else + throw BadVariantType(); +} + +template <typename FirstType, typename... RestTypes> +template <typename First, typename Second, typename... Rest> +void Variant<FirstType, RestTypes...>::doMakeType(VariantTypeIndex typeIndex) { + if (typeIndex == 0) + *this = First(); + else + return doMakeType<Second, Rest...>(typeIndex - 1); +} + +template <typename... Types> +template <typename T, typename> +constexpr VariantTypeIndex MVariant<Types...>::typeIndexOf() { + return Variant<MVariantEmpty, Types...>::template typeIndexOf<T>(); +} + +template <typename... Types> +MVariant<Types...>::MVariant() {} + +template <typename... Types> +MVariant<Types...>::MVariant(MVariant const& x) + : m_variant(x.m_variant) {} + +template <typename... Types> +MVariant<Types...>::MVariant(MVariant&& x) { + m_variant = move(x.m_variant); + x.m_variant = MVariantEmpty(); +} + +template <typename... Types> +MVariant<Types...>::MVariant(Variant<Types...> const& x) { + operator=(x); +} + +template <typename... Types> +MVariant<Types...>::MVariant(Variant<Types...>&& x) { + operator=(move(x)); +} + +template <typename... Types> +template <typename T, typename> +MVariant<Types...>::MVariant(T const& x) + : m_variant(x) {} + +template <typename... Types> +template <typename T, typename> +MVariant<Types...>::MVariant(T&& x) + : m_variant(forward<T>(x)) {} + +template <typename... Types> +MVariant<Types...>::~MVariant() {} + +template <typename... Types> +MVariant<Types...>& MVariant<Types...>::operator=(MVariant const& x) { + try { + m_variant = x.m_variant; + } catch (...) { + if (m_variant.invalid()) + m_variant = MVariantEmpty(); + throw; + } + return *this; +} + +template <typename... Types> +MVariant<Types...>& MVariant<Types...>::operator=(MVariant&& x) { + try { + m_variant = move(x.m_variant); + } catch (...) { + if (m_variant.invalid()) + m_variant = MVariantEmpty(); + throw; + } + return *this; +} + +template <typename... Types> +template <typename T, typename> +MVariant<Types...>& MVariant<Types...>::operator=(T const& x) { + try { + m_variant = x; + } catch (...) { + if (m_variant.invalid()) + m_variant = MVariantEmpty(); + throw; + } + return *this; +} + +template <typename... Types> +template <typename T, typename> +MVariant<Types...>& MVariant<Types...>::operator=(T&& x) { + try { + m_variant = forward<T>(x); + } catch (...) { + if (m_variant.invalid()) + m_variant = MVariantEmpty(); + throw; + } + return *this; +} + +template <typename... Types> +MVariant<Types...>& MVariant<Types...>::operator=(Variant<Types...> const& x) { + x.call([this](auto const& t) { + *this = t; + }); + return *this; +} + +template <typename... Types> +MVariant<Types...>& MVariant<Types...>::operator=(Variant<Types...>&& x) { + x.call([this](auto& t) { + *this = move(t); + }); + return *this; +} + +template <typename... Types> +bool MVariant<Types...>::operator==(MVariant const& x) const { + return m_variant == x.m_variant; +} + +template <typename... Types> +bool MVariant<Types...>::operator!=(MVariant const& x) const { + return m_variant != x.m_variant; +} + +template <typename... Types> +bool MVariant<Types...>::operator<(MVariant const& x) const { + return m_variant < x.m_variant; +} + +template <typename... Types> +template <typename T, typename> +bool MVariant<Types...>::operator==(T const& x) const { + return m_variant == x; +} + +template <typename... Types> +template <typename T, typename> +bool MVariant<Types...>::operator!=(T const& x) const { + return m_variant != x; +} + +template <typename... Types> +template <typename T, typename> +bool MVariant<Types...>::operator<(T const& x) const { + return m_variant < x; +} + +template <typename... Types> +template <typename T, typename> +T const& MVariant<Types...>::get() const { + return m_variant.template get<T>(); +} + +template <typename... Types> +template <typename T, typename> +T& MVariant<Types...>::get() { + return m_variant.template get<T>(); +} + +template <typename... Types> +template <typename T, typename> +Maybe<T> MVariant<Types...>::maybe() const { + return m_variant.template maybe<T>(); +} + +template <typename... Types> +template <typename T, typename> +T const* MVariant<Types...>::ptr() const { + return m_variant.template ptr<T>(); +} + +template <typename... Types> +template <typename T, typename> +T* MVariant<Types...>::ptr() { + return m_variant.template ptr<T>(); +} + +template <typename... Types> +template <typename T, typename> +bool MVariant<Types...>::is() const { + return m_variant.template is<T>(); +} + +template <typename... Types> +template <typename T, typename> +T MVariant<Types...>::take() { + T t = move(m_variant.template get<T>()); + m_variant = MVariantEmpty(); + return t; +} + +template <typename... Types> +Variant<Types...> MVariant<Types...>::value() const { + if (empty()) + throw BadVariantCast(); + + Variant<Types...> r; + call([&r](auto const& v) { + r = v; + }); + return r; +} + +template <typename... Types> +Variant<Types...> MVariant<Types...>::takeValue() { + if (empty()) + throw BadVariantCast(); + + Variant<Types...> r; + call([&r](auto& v) { + r = move(v); + }); + m_variant = MVariantEmpty(); + return r; +} + +template <typename... Types> +bool MVariant<Types...>::empty() const { + return m_variant.template is<MVariantEmpty>(); +} + +template <typename... Types> +void MVariant<Types...>::reset() { + m_variant = MVariantEmpty(); +} + +template <typename... Types> +MVariant<Types...>::operator bool() const { + return !empty(); +} + +template <typename... Types> +template <typename Function> +void MVariant<Types...>::call(Function&& function) { + m_variant.call(RefCaller<Function>(forward<Function>(function))); +} + +template <typename... Types> +template <typename Function> +void MVariant<Types...>::call(Function&& function) const { + m_variant.call(ConstRefCaller<Function>(forward<Function>(function))); +} + +template <typename... Types> +VariantTypeIndex MVariant<Types...>::typeIndex() const { + return m_variant.typeIndex(); +} + +template <typename... Types> +void MVariant<Types...>::makeType(VariantTypeIndex typeIndex) { + m_variant.makeType(typeIndex); +} + +template <typename... Types> +bool MVariant<Types...>::MVariantEmpty::operator==(MVariantEmpty const&) const { + return true; +} + +template <typename... Types> +bool MVariant<Types...>::MVariantEmpty::operator<(MVariantEmpty const&) const { + return false; +} + +template <typename... Types> +template <typename Function> +MVariant<Types...>::RefCaller<Function>::RefCaller(Function&& function) + : function(forward<Function>(function)) {} + +template <typename... Types> +template <typename Function> +void MVariant<Types...>::RefCaller<Function>::operator()(MVariantEmpty&) {} + +template <typename... Types> +template <typename Function> +template <typename T> +void MVariant<Types...>::RefCaller<Function>::operator()(T& t) { + function(t); +} + +template <typename... Types> +template <typename Function> +MVariant<Types...>::ConstRefCaller<Function>::ConstRefCaller(Function&& function) + : function(forward<Function>(function)) {} + +template <typename... Types> +template <typename Function> +void MVariant<Types...>::ConstRefCaller<Function>::operator()(MVariantEmpty const&) {} + +template <typename... Types> +template <typename Function> +template <typename T> +void MVariant<Types...>::ConstRefCaller<Function>::operator()(T const& t) { + function(t); +} + +} + +#endif diff --git a/source/core/StarVector.hpp b/source/core/StarVector.hpp new file mode 100644 index 0000000..aaa5909 --- /dev/null +++ b/source/core/StarVector.hpp @@ -0,0 +1,925 @@ +#ifndef STAR_VECTOR_HPP +#define STAR_VECTOR_HPP + +#include "StarArray.hpp" +#include "StarMathCommon.hpp" +#include "StarAlgorithm.hpp" +#include "StarHash.hpp" + +namespace Star { + +template <typename T, size_t N> +class Vector : public Array<T, N> { +public: + typedef Array<T, N> Base; + + template <size_t P, typename T2 = void> + using Enable2D = typename std::enable_if<P == 2 && N == P, T2>::type; + + template <size_t P, typename T2 = void> + using Enable3D = typename std::enable_if<P == 3 && N == P, T2>::type; + + template <size_t P, typename T2 = void> + using Enable4D = typename std::enable_if<P == 4 && N == P, T2>::type; + + template <size_t P, typename T2 = void> + using Enable2DOrHigher = typename std::enable_if<P >= 2 && N == P, T2>::type; + + template <size_t P, typename T2 = void> + using Enable3DOrHigher = typename std::enable_if<P >= 3 && N == P, T2>::type; + + template <size_t P, typename T2 = void> + using Enable4DOrHigher = typename std::enable_if<P >= 4 && N == P, T2>::type; + + static Vector filled(T const& t); + + template <typename T2> + static Vector floor(Vector<T2, N> const& v); + + template <typename T2> + static Vector ceil(Vector<T2, N> const& v); + + template <typename T2> + static Vector round(Vector<T2, N> const& v); + + template <typename Iterator> + static Vector copyFrom(Iterator p); + + // Is zero-initialized (from Array) + Vector(); + + explicit Vector(T const& e1); + + template <typename... TN> + Vector(T const& e1, TN const&... rest); + + template <typename T2> + explicit Vector(Array<T2, N> const& v); + + template <typename T2, typename T3> + Vector(Array<T2, N - 1> const& u, T3 const& v); + + template <size_t N2> + Vector<T, N2> toSize() const; + Vector<T, 2> vec2() const; + Vector<T, 3> vec3() const; + Vector<T, 4> vec4() const; + + Vector piecewiseMultiply(Vector const& v2) const; + Vector piecewiseDivide(Vector const& v2) const; + + Vector piecewiseMin(Vector const& v2) const; + Vector piecewiseMax(Vector const& v2) const; + Vector piecewiseClamp(Vector const& min, Vector const& max) const; + + T min() const; + T max() const; + + T sum() const; + T product() const; + + template <typename Function> + Vector combine(Vector const& v, Function f) const; + + // Outputs angles in the range [0, pi] + T angleBetween(Vector const& v) const; + + // Angle between two normalized vectors. + T angleBetweenNormalized(Vector const& v) const; + + T magnitudeSquared() const; + T magnitude() const; + + void normalize(); + Vector normalized() const; + + Vector projectOnto(Vector const& v) const; + + Vector projectOntoNormalized(Vector const& v) const; + + void negate(); + + // Reverses order of components of vector + void reverse(); + + Vector abs() const; + Vector floor() const; + Vector ceil() const; + Vector round() const; + + void fill(T const& v); + void clamp(T const& min, T const& max); + + template <typename Function> + void transform(Function&& function); + + template <typename Function> + Vector<decltype(std::declval<Function>()(std::declval<T>())), N> transformed(Function&& function) const; + + Vector operator-() const; + + Vector operator+(Vector const& v) const; + Vector operator-(Vector const& v) const; + T operator*(Vector const& v) const; + Vector operator*(T s) const; + Vector operator/(T s) const; + Vector& operator+=(Vector const& v); + Vector& operator-=(Vector const& v); + Vector& operator*=(T s); + Vector& operator/=(T s); + + // Vector2 + + // Return vector rotated to given angle + template <size_t P = N> + static Enable2D<P, Vector> withAngle(T angle, T magnitude = 1); + + template <size_t P = N> + static Enable2D<P, T> angleBetween2(Vector const& u, Vector const& v); + template <size_t P = N> + static Enable2D<P, T> angleFormedBy2(Vector const& a, Vector const& b, Vector const& c); + template <size_t P = N> + static Enable2D<P, T> angleFormedBy2(Vector const& a, Vector const& b, Vector const& c, std::function<Vector(Vector, Vector)> const& diff); + + template <size_t P = N> + Enable2D<P, Vector> rotate(T angle) const; + + // Faster than rotate(Constants::pi/2). + template <size_t P = N> + Enable2D<P, Vector> rot90() const; + + // Angle of vector on 2d plane, in the range [-pi, pi] + template <size_t P = N> + Enable2D<P, T> angle() const; + + // Returns polar coordinates of this cartesian vector + template <size_t P = N> + Enable2D<P, Vector> toPolar() const; + + // Returns cartesian coordinates of this polar vector + template <size_t P = N> + Enable2D<P, Vector> toCartesian() const; + + template <size_t P = N> + Enable2DOrHigher<P, T> const& x() const; + template <size_t P = N> + Enable2DOrHigher<P, T> const& y() const; + + template <size_t P = N> + Enable2DOrHigher<P> setX(T const& t); + template <size_t P = N> + Enable2DOrHigher<P> setY(T const& t); + + // Vector3 + + template <size_t P = N> + static Enable3D<P, Vector> fromAngles(T psi, T theta); + template <size_t P = N> + static Enable3D<P, Vector> fromAnglesEnu(T psi, T theta); + template <size_t P = N> + static Enable3D<P, T> tripleScalarProduct(Vector const& u, Vector const& v, Vector const& w); + template <size_t P = N> + static Enable3D<P, T> angle(Vector const& v1, Vector const& v2); + + template <size_t P = N> + Enable3D<P, T> psi() const; + template <size_t P = N> + Enable3D<P, T> theta() const; + template <size_t P = N> + Enable3D<P, Vector<T, 2>> eulers() const; + + template <size_t P = N> + Enable3D<P, T> psiEnu() const; + template <size_t P = N> + Enable3D<P, T> thetaEnu() const; + + template <size_t P = N> + Enable3D<P, Vector> nedToEnu() const; + template <size_t P = N> + Enable3D<P, Vector> enuToNed() const; + + template <size_t P = N> + Enable3DOrHigher<P, T> const& z() const; + + template <size_t P = N> + Enable3DOrHigher<P> setZ(T const& t); + + // Vector4 + + template <size_t P = N> + Enable4DOrHigher<P, T> const& w() const; + + template <size_t P = N> + Enable4DOrHigher<P> setW(T const& t); + +private: + using Base::size; + using Base::empty; +}; + +typedef Vector<int, 2> Vec2I; +typedef Vector<unsigned, 2> Vec2U; +typedef Vector<float, 2> Vec2F; +typedef Vector<double, 2> Vec2D; +typedef Vector<uint8_t, 2> Vec2B; +typedef Vector<size_t, 2> Vec2S; + +typedef Vector<int, 3> Vec3I; +typedef Vector<unsigned, 3> Vec3U; +typedef Vector<float, 3> Vec3F; +typedef Vector<double, 3> Vec3D; +typedef Vector<uint8_t, 3> Vec3B; +typedef Vector<size_t, 3> Vec3S; + +typedef Vector<int, 4> Vec4I; +typedef Vector<unsigned, 4> Vec4U; +typedef Vector<float, 4> Vec4F; +typedef Vector<double, 4> Vec4D; +typedef Vector<uint8_t, 4> Vec4B; +typedef Vector<size_t, 4> Vec4S; + +template <typename T, size_t N> +std::ostream& operator<<(std::ostream& os, Vector<T, N> const& v); + +template <typename T, size_t N> +Vector<T, N> operator*(T s, Vector<T, N> v); + +template <typename T, size_t N> +Vector<T, N> vnorm(Vector<T, N> v); + +template <typename T, size_t N> +T vmag(Vector<T, N> const& v); + +template <typename T, size_t N> +T vmagSquared(Vector<T, N> const& v); + +template <typename T, size_t N> +Vector<T, N> vmin(Vector<T, N> const& a, Vector<T, N> const& b); + +template <typename T, size_t N> +Vector<T, N> vmax(Vector<T, N> const& a, Vector<T, N> const& b); + +template <typename T, size_t N> +Vector<T, N> vclamp(Vector<T, N> const& a, Vector<T, N> const& min, Vector<T, N> const& max); + +template <typename VectorType> +VectorType vmult(VectorType const& a, VectorType const& b); + +template <typename VectorType> +VectorType vdiv(VectorType const& a, VectorType const& b); + +// Returns the cross product +template <typename T> +Vector<T, 3> operator^(Vector<T, 3> v1, Vector<T, 3> v2); + +// Returns the cross product / determinant +template <typename T> +T operator^(Vector<T, 2> const& v1, Vector<T, 2> const& v2); + +template <typename T, size_t N> +struct hash<Vector<T, N>> : hash<Array<T, N>> {}; + +template <typename T, size_t N> +Vector<T, N> Vector<T, N>::filled(T const& t) { + Vector v; + for (size_t i = 0; i < N; ++i) + v[i] = t; + return v; +} + +template <typename T, size_t N> +template <typename T2> +Vector<T, N> Vector<T, N>::floor(Vector<T2, N> const& v) { + Vector vec; + for (size_t i = 0; i < N; ++i) + vec[i] = Star::floor(v[i]); + return vec; +} + +template <typename T, size_t N> +template <typename T2> +Vector<T, N> Vector<T, N>::ceil(Vector<T2, N> const& v) { + Vector vec; + for (size_t i = 0; i < N; ++i) + vec[i] = Star::ceil(v[i]); + return vec; +} + +template <typename T, size_t N> +template <typename T2> +Vector<T, N> Vector<T, N>::round(Vector<T2, N> const& v) { + Vector vec; + for (size_t i = 0; i < N; ++i) + vec[i] = Star::round(v[i]); + return vec; +} + +template <typename T, size_t N> +template <typename Iterator> +Vector<T, N> Vector<T, N>::copyFrom(Iterator p) { + Vector v; + for (size_t i = 0; i < N; ++i) + v[i] = *(p++); + return v; +} + +template <typename T, size_t N> +Vector<T, N>::Vector() {} + +template <typename T, size_t N> +Vector<T, N>::Vector(T const& e1) + : Base(e1) {} + +template <typename T, size_t N> +template <typename... TN> +Vector<T, N>::Vector(T const& e1, TN const&... rest) + : Base(e1, rest...) {} + +template <typename T, size_t N> +template <typename T2> +Vector<T, N>::Vector(Array<T2, N> const& v) + : Base(v) {} + +template <typename T, size_t N> +template <typename T2, typename T3> +Vector<T, N>::Vector(Array<T2, N - 1> const& u, T3 const& v) { + for (size_t i = 0; i < N - 1; ++i) { + Base::operator[](i) = u[i]; + } + Base::operator[](N - 1) = v; +} + +template <typename T, size_t N> +template <size_t N2> +Vector<T, N2> Vector<T, N>::toSize() const { + Vector<T, N2> r; + size_t ns = Star::min(N2, N); + for (size_t i = 0; i < ns; ++i) + r[i] = (*this)[i]; + return r; +} + +template <typename T, size_t N> +Vector<T, 2> Vector<T, N>::vec2() const { + return toSize<2>(); +} + +template <typename T, size_t N> +Vector<T, 3> Vector<T, N>::vec3() const { + return toSize<3>(); +} + +template <typename T, size_t N> +Vector<T, 4> Vector<T, N>::vec4() const { + return toSize<4>(); +} + +template <typename T, size_t N> +Vector<T, N> Vector<T, N>::piecewiseMultiply(Vector const& v2) const { + return combine(v2, std::multiplies<T>()); +} + +template <typename T, size_t N> +Vector<T, N> Vector<T, N>::piecewiseDivide(Vector const& v2) const { + return combine(v2, std::divides<T>()); +} + +template <typename T, size_t N> +Vector<T, N> Vector<T, N>::piecewiseMin(Vector const& v2) const { + Vector r; + for (size_t i = 0; i < N; ++i) + r[i] = Star::min((*this)[i], v2[i]); + return r; +} + +template <typename T, size_t N> +Vector<T, N> Vector<T, N>::piecewiseMax(Vector const& v2) const { + Vector r; + for (size_t i = 0; i < N; ++i) + r[i] = Star::max((*this)[i], v2[i]); + return r; +} + +template <typename T, size_t N> +Vector<T, N> Vector<T, N>::piecewiseClamp(Vector const& min, Vector const& max) const { + Vector r; + for (size_t i = 0; i < N; ++i) + r[i] = Star::max(Star::min((*this)[i], max[i]), min[i]); + return r; +} + +template <typename T, size_t N> +T Vector<T, N>::min() const { + T s = (*this)[0]; + for (size_t i = 1; i < N; ++i) + s = Star::min(s, (*this)[i]); + return s; +} + +template <typename T, size_t N> +T Vector<T, N>::max() const { + T s = (*this)[0]; + for (size_t i = 1; i < N; ++i) + s = Star::max(s, (*this)[i]); + return s; +} + +template <typename T, size_t N> +T Vector<T, N>::sum() const { + T s = (*this)[0]; + for (size_t i = 1; i < N; ++i) + s += (*this)[i]; + return s; +} + +template <typename T, size_t N> +T Vector<T, N>::product() const { + T p = (*this)[0]; + for (size_t i = 1; i < N; ++i) + p *= (*this)[i]; + return p; +} + +template <typename T, size_t N> +template <typename Function> +Vector<T, N> Vector<T, N>::combine(Vector const& v, Function f) const { + Vector r; + for (size_t i = 0; i < N; ++i) + r[i] = f((*this)[i], v[i]); + return r; +} + +template <typename T, size_t N> +T Vector<T, N>::angleBetween(Vector const& v) const { + return acos(this->normalized() * v.normalized()); +} + +template <typename T, size_t N> +T Vector<T, N>::angleBetweenNormalized(Vector const& v) const { + return acos(*this * v); +} + +template <typename T, size_t N> +T Vector<T, N>::magnitudeSquared() const { + T m = 0; + for (size_t i = 0; i < N; ++i) + m += square((*this)[i]); + return m; +} + +template <typename T, size_t N> +T Vector<T, N>::magnitude() const { + return sqrt(magnitudeSquared()); +} + +template <typename T, size_t N> +void Vector<T, N>::normalize() { + T m = magnitude(); + if (m != 0) + *this = (*this) / m; +} + +template <typename T, size_t N> +Vector<T, N> Vector<T, N>::normalized() const { + T m = magnitude(); + if (m != 0) + return (*this) / m; + else + return *this; +} + +template <typename T, size_t N> +Vector<T, N> Vector<T, N>::projectOnto(Vector const& v) const { + T m = v.magnitudeSquared(); + if (m != 0) + return projectOntoNormalized(v) / m; + else + return Vector(); +} + +template <typename T, size_t N> +Vector<T, N> Vector<T, N>::projectOntoNormalized(Vector const& v) const { + return ((*this) * v) * v; +} + +template <typename T, size_t N> +Vector<T, N> Vector<T, N>::operator-() const { + auto v = *this; + v.negate(); + return v; +} + +template <typename T, size_t N> +void Vector<T, N>::negate() { + for (size_t i = 0; i < N; ++i) + (*this)[i] = -(*this)[i]; +} + +template <typename T, size_t N> +Vector<T, N> Vector<T, N>::abs() const { + Vector v; + for (size_t i = 0; i < N; ++i) + v[i] = fabs((*this)[i]); + return v; +} + +template <typename T, size_t N> +Vector<T, N> Vector<T, N>::floor() const { + return floor(*this); +} + +template <typename T, size_t N> +Vector<T, N> Vector<T, N>::ceil() const { + return ceil(*this); +} + +template <typename T, size_t N> +Vector<T, N> Vector<T, N>::round() const { + return round(*this); +} + +template <typename T, size_t N> +void Vector<T, N>::reverse() { + std::reverse(Base::begin(), Base::end()); +} + +template <typename T, size_t N> +void Vector<T, N>::fill(T const& v) { + Base::fill(v); +} + +template <typename T, size_t N> +void Vector<T, N>::clamp(T const& min, T const& max) { + for (size_t i = 0; i < N; ++i) + (*this)[i] = Star::max(min, Star::min(max, (*this)[i])); +} + +template <typename T, size_t N> +template <typename Function> +void Vector<T, N>::transform(Function&& function) { + for (auto& e : *this) + e = function(e); +} + +template <typename T, size_t N> +template <typename Function> +Vector<decltype(std::declval<Function>()(std::declval<T>())), N> Vector<T, N>::transformed(Function&& function) const { + return Star::transform<Vector<decltype(std::declval<Function>()(std::declval<T>())), N>>(*this, function); +} + +template <typename T, size_t N> +Vector<T, N> Vector<T, N>::operator+(Vector const& v) const { + Vector r; + for (size_t i = 0; i < N; ++i) + r[i] = (*this)[i] + v[i]; + return r; +} + +template <typename T, size_t N> +Vector<T, N> Vector<T, N>::operator-(Vector const& v) const { + Vector r; + for (size_t i = 0; i < N; ++i) + r[i] = (*this)[i] - v[i]; + return r; +} + +template <typename T, size_t N> +T Vector<T, N>::operator*(Vector const& v) const { + T sum = 0; + for (size_t i = 0; i < N; ++i) + sum += (*this)[i] * v[i]; + return sum; +} + +template <typename T, size_t N> +Vector<T, N> Vector<T, N>::operator*(T s) const { + Vector r; + for (size_t i = 0; i < N; ++i) + r[i] = (*this)[i] * s; + return r; +} + +template <typename T, size_t N> +Vector<T, N> Vector<T, N>::operator/(T s) const { + Vector r; + for (size_t i = 0; i < N; ++i) + r[i] = (*this)[i] / s; + return r; +} + +template <typename T, size_t N> +Vector<T, N>& Vector<T, N>::operator+=(Vector const& v) { + return (*this = *this + v); +} + +template <typename T, size_t N> +Vector<T, N>& Vector<T, N>::operator-=(Vector const& v) { + return (*this = *this - v); +} + +template <typename T, size_t N> +Vector<T, N>& Vector<T, N>::operator*=(T s) { + return (*this = *this * s); +} + +template <typename T, size_t N> +Vector<T, N>& Vector<T, N>::operator/=(T s) { + return (*this = *this / s); +} + +// Vector2 + +template <typename T, size_t N> +template <size_t P> +auto Vector<T, N>::withAngle(T angle, T magnitude) -> Enable2D<P, Vector<T, N>> { + return Vector(std::cos(angle) * magnitude, std::sin(angle) * magnitude); +} + +template <typename T, size_t N> +template <size_t P> +auto Vector<T, N>::angleBetween2(Vector const& v1, Vector const& v2) -> Enable2D<P, T> { + // TODO: Inefficient + return v2.angle() - v1.angle(); +} + +template <typename T, size_t N> +template <size_t P> +auto Vector<T, N>::angleFormedBy2(Vector const& a, Vector const& b, Vector const& c) -> Enable2D<P, T> { + return angleBetween2(b - a, b - c); +} + +template <typename T, size_t N> +template <size_t P> +auto Vector<T, N>::angleFormedBy2( + Vector const& a, Vector const& b, Vector const& c, std::function<Vector(Vector, Vector)> const& diff) + -> Enable2D<P, T> { + return angleBetween2(diff(b, a), diff(b, c)); +} + +template <typename T, size_t N> +template <size_t P> +auto Vector<T, N>::angle() const -> Enable2D<P, T> { + return atan2(Base::operator[](1), Base::operator[](0)); +} + +template <typename T, size_t N> +template <size_t P> +auto Vector<T, N>::rotate(T a) const -> Enable2D<P, Vector<T, N>> { + // TODO: Need a Matrix2 + T cosa = std::cos(a); + T sina = std::sin(a); + return Vector( + Base::operator[](0) * cosa - Base::operator[](1) * sina, Base::operator[](0) * sina + Base::operator[](1) * cosa); +} + +template <typename T, size_t N> +template <size_t P> +auto Vector<T, N>::rot90() const -> Enable2D<P, Vector<T, N>> { + return Vector(-y(), x()); +} + +template <typename T, size_t N> +template <size_t P> +auto Vector<T, N>::toPolar() const -> Enable2D<P, Vector<T, N>> { + return Vector(angle(), Base::magnitude()); +} + +template <typename T, size_t N> +template <size_t P> +auto Vector<T, N>::toCartesian() const -> Enable2D<P, Vector<T, N>> { + return vec2d(sin((*this)[0]) * (*this)[1], cos((*this)[0]) * (*this)[1]); +} + +template <typename T, size_t N> +template <size_t P> +auto Vector<T, N>::x() const -> Enable2DOrHigher<P, T> const & { + return Base::operator[](0); +} + +template <typename T, size_t N> +template <size_t P> +auto Vector<T, N>::y() const -> Enable2DOrHigher<P, T> const & { + return Base::operator[](1); +} + +template <typename T, size_t N> +template <size_t P> +auto Vector<T, N>::setX(T const& t) -> Enable2DOrHigher<P> { + Base::operator[](0) = t; +} + +template <typename T, size_t N> +template <size_t P> +auto Vector<T, N>::setY(T const& t) -> Enable2DOrHigher<P> { + Base::operator[](1) = t; +} + +// Vector3 + +template <typename T, size_t N> +template <size_t P> +auto Vector<T, N>::tripleScalarProduct(Vector const& a, Vector const& b, Vector const& c) -> Enable3D<P, T> { + return a * (b ^ c); +} + +template <typename T, size_t N> +template <size_t P> +auto Vector<T, N>::theta() const -> Enable3D<P, T> { + Vector<T, N> vn = norm(*this); + T tmp = fabs(vn.z()); + if (tmp > 0.99999) { + return tmp > 0.0 ? T(-Constants::pi / 2) : T(Constants::pi / 2); + } else { + return asin(-vn.z()); + } +} + +template <typename T, size_t N> +template <size_t P> +auto Vector<T, N>::psi() const -> Enable3D<P, T> { + Vector<T, N> vn = norm(*this); + T tmp = T(fabs(vn.z())); + if (tmp > 0.99999) { + return 0.0; + } else { + return T(atan2(vn.y(), vn.x())); + } +} + +template <typename T, size_t N> +template <size_t P> +auto Vector<T, N>::thetaEnu() const -> Enable3D<P, T> { + Vector<T, N> vn = norm(*this); + T tmp = fabs(vn.z()); + if (tmp > 0.99999) { + return tmp > 0.0 ? -Constants::pi / 2 : Constants::pi / 2; + } else { + return asin(vn.z()); + } +} + +template <typename T, size_t N> +template <size_t P> +auto Vector<T, N>::psiEnu() const -> Enable3D<P, T> { + Vector<T, N> vn = norm(*this); + T tmp = fabs(vn.z()); + if (tmp > 0.99999) { + return 0.0; + } else { + return atan2(vn.x(), vn.y()); + } +} + +template <typename T, size_t N> +template <size_t P> +auto Vector<T, N>::eulers() const -> Enable3D<P, Vector<T, 2>> { + T psi, theta; + Vector<T, N> vn = norm(*this); + T tmp = fabs(vn.z()); + if (tmp > 0.99999) { + psi = 0.0; + theta = tmp > 0.0 ? -Constants::pi / 2 : Constants::pi / 2; + } else { + psi = atan2(vn.y(), vn.x()); + theta = asin(-vn.z()); + } + return Vector<T, 2>(psi, theta); +} + +template <typename T, size_t N> +template <size_t P> +auto Vector<T, N>::fromAngles(T psi, T theta) -> Enable3D<P, Vector<T, N>> { + Vec3F nv; + T cosTheta = T(cos(theta)); + + nv.x() = T(cos(psi)); + nv.y() = T(sin(psi)); + nv.x() *= cosTheta; + nv.y() *= cosTheta; + nv.z() = T(-sin(theta)); + return nv; +} + +template <typename T, size_t N> +template <size_t P> +auto Vector<T, N>::fromAnglesEnu(T psi, T theta) -> Enable3D<P, Vector<T, N>> { + Vector nv = fromAngles(psi, theta); + return Vector(nv.y(), nv.x(), -nv.z()); +} + +template <typename T, size_t N> +template <size_t P> +auto Vector<T, N>::angle(Vector const& v1, Vector const& v2) -> Enable3D<P, T> { + return acos(Star::min(norm(v1) * norm(v2), 1.0)); +} + +template <typename T, size_t N> +template <size_t P> +auto Vector<T, N>::nedToEnu() const -> Enable3D<P, Vector<T, N>> { + return Vector(y(), x(), -z()); +} + +template <typename T, size_t N> +template <size_t P> +auto Vector<T, N>::enuToNed() const -> Enable3D<P, Vector<T, N>> { + return Vector(y(), x(), -z()); +} + +template <typename T, size_t N> +template <size_t P> +auto Vector<T, N>::z() const -> Enable3DOrHigher<P, T> const & { + return Base::operator[](2); +} + +template <typename T, size_t N> +template <size_t P> +auto Vector<T, N>::setZ(T const& t) -> Enable3DOrHigher<P> { + Base::operator[](2) = t; +} + +// Vector4 + +template <typename T, size_t N> +template <size_t P> +auto Vector<T, N>::w() const -> Enable4DOrHigher<P, T> const & { + return Base::operator[](3); +} + +template <typename T, size_t N> +template <size_t P> +auto Vector<T, N>::setW(T const& t) -> Enable4DOrHigher<P> { + Base::operator[](3) = t; +} + +// Free Functions + +template <typename T, size_t N> +std::ostream& operator<<(std::ostream& os, Vector<T, N> const& v) { + os << '('; + for (size_t i = 0; i < N; ++i) { + os << v[i]; + if (i != N - 1) + os << ", "; + } + os << ')'; + return os; +} + +template <typename T, size_t N> +Vector<T, N> operator*(T s, Vector<T, N> v) { + return v * s; +} + +template <typename T, size_t N> +Vector<T, N> vnorm(Vector<T, N> v) { + return v.normalized(); +} + +template <typename T, size_t N> +T vmag(Vector<T, N> const& v) { + return v.magnitude(); +} + +template <typename T, size_t N> +T vmagSquared(Vector<T, N> const& v) { + return v.magnitudeSquared(); +} + +template <typename T, size_t N> +Vector<T, N> vmin(Vector<T, N> const& a, Vector<T, N> const& b) { + return a.piecewiseMin(b); +} + +template <typename T, size_t N> +Vector<T, N> vmax(Vector<T, N> const& a, Vector<T, N> const& b) { + return a.piecewiseMax(b); +} + +template <typename T, size_t N> +Vector<T, N> vclamp(Vector<T, N> const& a, Vector<T, N> const& min, Vector<T, N> const& max) { + return a.piecewiseClamp(min, max); +} + +template <typename VectorType> +VectorType vmult(VectorType const& a, VectorType const& b) { + return a.piecewiseMultiply(b); +} + +template <typename VectorType> +VectorType vdiv(VectorType const& a, VectorType const& b) { + return a.piecewiseDivide(b); +} + +template <typename T> +Vector<T, 3> operator^(Vector<T, 3> v1, Vector<T, 3> v2) { + return Vector<T, 3>(v1[1] * v2[2] - v1[2] * v2[1], v1[2] * v2[0] - v1[0] * v2[2], v1[0] * v2[1] - v1[1] * v2[0]); +} + +template <typename T> +T operator^(Vector<T, 2> const& v1, Vector<T, 2> const& v2) { + return v1[0] * v2[1] - v1[1] * v2[0]; +} + +} + +#endif diff --git a/source/core/StarVlqEncoding.hpp b/source/core/StarVlqEncoding.hpp new file mode 100644 index 0000000..2139ed1 --- /dev/null +++ b/source/core/StarVlqEncoding.hpp @@ -0,0 +1,106 @@ +#ifndef STAR_VLQ_ENCODING_HPP +#define STAR_VLQ_ENCODING_HPP + +#include "StarMemory.hpp" + +namespace Star { + +// Write an unsigned integer as a VLQ (Variable Length Quantity). Writes the +// integer in 7 byte chunks, with the 8th bit of each octet indicates whether +// another chunk follows. Endianness independent, as the chunks are always +// written most significant first. Returns number of octet written (writes a +// maximum of a 64 bit integer, so a maximum of 10) +template <typename OutputIterator> +size_t writeVlqU(uint64_t x, OutputIterator out) { + size_t i; + for (i = 9; i > 0; --i) { + if (x & ((uint64_t)(127) << (i * 7))) + break; + } + + for (size_t j = 0; j < i; ++j) + *out++ = (uint8_t)((x >> ((i - j) * 7)) & 127) | 128; + + *out++ = (uint8_t)(x & 127); + return i + 1; +} + +inline size_t vlqUSize(uint64_t x) { + size_t i; + for (i = 9; i > 0; --i) { + if (x & ((uint64_t)(127) << (i * 7))) + break; + } + return i + 1; +} + +// Read a VLQ (Variable Length Quantity) encoded unsigned integer. Returns +// number of bytes read. Reads a *maximum of 10 bytes*, cannot read a larger +// than 64 bit integer! If no end marker is found within 'maxBytes' or 10 +// bytes, whichever is smaller, then will return NPos to signal error. +template <typename InputIterator> +size_t readVlqU(uint64_t& x, InputIterator in, size_t maxBytes = 10) { + x = 0; + for (size_t i = 0; i < min<size_t>(10, maxBytes); ++i) { + uint8_t oct = *in++; + x = (x << 7) | (uint64_t)(oct & 127); + if (!(oct & 128)) + return i + 1; + } + + return NPos; +} + +// Write a VLQ (Variable Length Quantity) encoded signed integer. Encoded by +// making the sign bit the least significant bit in the integer. Returns +// number of bytes written. +template <typename OutputIterator> +size_t writeVlqI(int64_t v, OutputIterator out) { + uint64_t target; + + // If negative, then add 1 to properly encode -2^63 + if (v < 0) + target = ((-(v + 1)) << 1) | 1; + else + target = v << 1; + + return writeVlqU(target, out); +} + +inline size_t vlqISize(int64_t v) { + uint64_t target; + + // If negative, then add 1 to properly encode -2^63 + if (v < 0) + target = ((-(v + 1)) << 1) | 1; + else + target = v << 1; + + return vlqUSize(target); +} + +// Read a VLQ (Variable Length Quantity) encoded signed integer. Returns +// number of bytes read. Reads a *maximum of 10 bytes*, cannot read a larger +// than 64 bit integer! If no end marker is found within 'maxBytes' or 10 +// bytes, whichever is smaller, then will return NPos to signal error. +template <typename InputIterator> +size_t readVlqI(int64_t& v, InputIterator in, size_t maxBytes = 10) { + uint64_t source; + size_t bytes = readVlqU(source, in, maxBytes); + if (bytes == NPos) + return NPos; + + bool negative = (source & 1); + + // If negative, then need to undo the +1 transformation to encode -2^63 + if (negative) + v = -(int64_t)(source >> 1) - 1; + else + v = (int64_t)(source >> 1); + + return bytes; +} + +} + +#endif diff --git a/source/core/StarWeightedPool.hpp b/source/core/StarWeightedPool.hpp new file mode 100644 index 0000000..c335d31 --- /dev/null +++ b/source/core/StarWeightedPool.hpp @@ -0,0 +1,196 @@ +#ifndef STAR_WEIGHTED_POOL_HPP +#define STAR_WEIGHTED_POOL_HPP + +#include "StarRandom.hpp" + +namespace Star { + +template <typename Item> +struct WeightedPool { +public: + typedef pair<double, Item> ItemsType; + typedef List<ItemsType> ItemsList; + + WeightedPool(); + + template <typename Container> + explicit WeightedPool(Container container); + + void add(double weight, Item item); + void clear(); + + ItemsList const& items() const; + + size_t size() const; + pair<double, Item> const& at(size_t index) const; + double weight(size_t index) const; + Item const& item(size_t index) const; + bool empty() const; + + // Return item using the given randomness source + Item select(RandomSource& rand) const; + // Return item using the global randomness source + Item select() const; + // Return item using fast static randomness from the given seed + Item select(uint64_t seed) const; + + // Return a list of n items which are selected uniquely (by index), where + // n is the lesser of the desiredCount and the size of the pool. + // This INFLUENCES PROBABILITIES so it should not be used where a + // correct statistical distribution is required. + List<Item> selectUniques(size_t desiredCount) const; + List<Item> selectUniques(size_t desiredCount, uint64_t seed) const; + + size_t selectIndex(RandomSource& rand) const; + size_t selectIndex() const; + size_t selectIndex(uint64_t seed) const; + +private: + size_t selectIndex(double target) const; + + ItemsList m_items; + double m_totalWeight; +}; + +template <typename Item> +WeightedPool<Item>::WeightedPool() + : m_totalWeight(0.0) {} + +template <typename Item> +template <typename Container> +WeightedPool<Item>::WeightedPool(Container container) + : WeightedPool() { + for (auto const& pair : container) + add(get<0>(pair), get<1>(pair)); +} + +template <typename Item> +void WeightedPool<Item>::add(double weight, Item item) { + if (weight <= 0.0) + return; + + m_items.append({weight, move(item)}); + m_totalWeight += weight; +} + +template <typename Item> +void WeightedPool<Item>::clear() { + m_items.clear(); + m_totalWeight = 0.0; +} + +template <typename Item> +auto WeightedPool<Item>::items() const -> ItemsList const & { + return m_items; +} + +template <typename Item> +size_t WeightedPool<Item>::size() const { + return m_items.count(); +} + +template <typename Item> +pair<double, Item> const& WeightedPool<Item>::at(size_t index) const { + return m_items.at(index); +} + +template <typename Item> +double WeightedPool<Item>::weight(size_t index) const { + return at(index).first; +} + +template <typename Item> +Item const& WeightedPool<Item>::item(size_t index) const { + return at(index).second; +} + +template <typename Item> +bool WeightedPool<Item>::empty() const { + return m_items.empty(); +} + +template <typename Item> +Item WeightedPool<Item>::select(RandomSource& rand) const { + if (m_items.empty()) + return Item(); + + return m_items[selectIndex(rand)].second; +} + +template <typename Item> +Item WeightedPool<Item>::select() const { + if (m_items.empty()) + return Item(); + + return m_items[selectIndex()].second; +} + +template <typename Item> +Item WeightedPool<Item>::select(uint64_t seed) const { + if (m_items.empty()) + return Item(); + + return m_items[selectIndex(seed)].second; +} + +template <typename Item> +List<Item> WeightedPool<Item>::selectUniques(size_t desiredCount) const { + return selectUniques(desiredCount, Random::randu64()); +} + +template <typename Item> +List<Item> WeightedPool<Item>::selectUniques(size_t desiredCount, uint64_t seed) const { + size_t targetCount = std::min(desiredCount, size()); + Set<size_t> indices; + while (indices.size() < targetCount) + indices.add(selectIndex(++seed)); + List<Item> result; + for (size_t i : indices) + result.append(m_items[i].second); + return result; +} + +template <typename Item> +size_t WeightedPool<Item>::selectIndex(RandomSource& rand) const { + return selectIndex(rand.randd()); +} + +template <typename Item> +size_t WeightedPool<Item>::selectIndex() const { + return selectIndex(Random::randd()); +} + +template <typename Item> +size_t WeightedPool<Item>::selectIndex(uint64_t seed) const { + return selectIndex(staticRandomDouble(seed)); +} + +template <typename Item> +size_t WeightedPool<Item>::selectIndex(double target) const { + if (m_items.empty()) + return NPos; + + // Test a randomly generated target against each weighted item in turn, and + // see if that weighted item's weight value crosses the target. This way, a + // random item is picked from the list, but (roughly) weighted to be + // proportional to its weight over the weight of all entries. + // + // TODO: This is currently O(n), but can easily be made O(log(n)) by using a + // tree. If this shows up in performance measurements, this is an obvious + // improvement. + + double accumulatedWeight = 0.0f; + for (size_t i = 0; i < m_items.size(); ++i) { + accumulatedWeight += m_items[i].first / m_totalWeight; + if (target <= accumulatedWeight) + return i; + } + + // If we haven't crossed the target, just assume floating point error has + // caused us to not quite make it to the last item. + return m_items.size() - 1; +} + +} + +#endif diff --git a/source/core/StarWorkerPool.cpp b/source/core/StarWorkerPool.cpp new file mode 100644 index 0000000..fa37a57 --- /dev/null +++ b/source/core/StarWorkerPool.cpp @@ -0,0 +1,166 @@ +#include "StarWorkerPool.hpp" +#include "StarIterator.hpp" +#include "StarMathCommon.hpp" + +namespace Star { + +bool WorkerPoolHandle::done() const { + MutexLocker locker(m_impl->mutex); + return m_impl->done; +} + +bool WorkerPoolHandle::wait(unsigned millis) const { + MutexLocker locker(m_impl->mutex); + + if (!m_impl->done && millis != 0) + m_impl->condition.wait(m_impl->mutex, millis); + + if (m_impl->exception) + std::rethrow_exception(m_impl->exception); + + return m_impl->done; +} + +bool WorkerPoolHandle::poll() const { + return wait(0); +} + +void WorkerPoolHandle::finish() const { + MutexLocker locker(m_impl->mutex); + + if (!m_impl->done) + m_impl->condition.wait(m_impl->mutex); + + if (m_impl->exception) + std::rethrow_exception(m_impl->exception); + + return; +} + +WorkerPoolHandle::Impl::Impl() : done(false) {} + +WorkerPoolHandle::WorkerPoolHandle(shared_ptr<Impl> impl) : m_impl(move(impl)) {} + +WorkerPool::WorkerPool(String name) : m_name(move(name)) {} + +WorkerPool::WorkerPool(String name, unsigned threadCount) : WorkerPool(move(name)) { + start(threadCount); +} + +WorkerPool::~WorkerPool() { + stop(); +} + +WorkerPool::WorkerPool(WorkerPool&&) = default; +WorkerPool& WorkerPool::operator=(WorkerPool&&) = default; + +void WorkerPool::start(unsigned threadCount) { + MutexLocker threadLock(m_threadMutex); + + for (auto const& workerThread : m_workerThreads) + workerThread->shouldStop = true; + + m_workCondition.broadcast(); + m_workerThreads.clear(); + + for (size_t i = m_workerThreads.size(); i < threadCount; ++i) + m_workerThreads.append(make_unique<WorkerThread>(this)); +} + +void WorkerPool::stop() { + MutexLocker threadLock(m_threadMutex); + for (auto const& workerThread : m_workerThreads) + workerThread->shouldStop = true; + + { + // Must hold the work lock while broadcasting to ensure that any worker + // threads that might wait without stopping actually get the signal. + MutexLocker workLock(m_workMutex); + m_workCondition.broadcast(); + } + + m_workerThreads.clear(); +} + +void WorkerPool::finish() { + // This is kind of a weird way to "wait" until all the pending work is + // finished. In order for the currently active worker threads to + // cooperatively complete the remaining work, the work lock must not be held + // the entire time (then just this thread would be the one finishing the + // work). Instead, the calling thread joins in on the action and tries to + // finish work while yielding to the other threads after each completed job. + MutexLocker workMutex(m_workMutex); + while (!m_pendingWork.empty()) { + auto firstWork = m_pendingWork.takeFirst(); + workMutex.unlock(); + firstWork(); + Thread::yield(); + workMutex.lock(); + } + workMutex.unlock(); + + stop(); +} + +WorkerPoolHandle WorkerPool::addWork(function<void()> work) { + // Construct a worker pool handle and wrap the work to signal the handle when + // finished. Set the result to empty string if successful and to the content + // of the exception if an exception is thrown. + auto workerPoolHandleImpl = make_shared<WorkerPoolHandle::Impl>(); + queueWork([workerPoolHandleImpl, work]() { + try { + work(); + MutexLocker handleLocker(workerPoolHandleImpl->mutex); + workerPoolHandleImpl->done = true; + workerPoolHandleImpl->condition.broadcast(); + } catch (...) { + MutexLocker handleLocker(workerPoolHandleImpl->mutex); + workerPoolHandleImpl->done = true; + workerPoolHandleImpl->exception = std::current_exception(); + workerPoolHandleImpl->condition.broadcast(); + } + }); + + return workerPoolHandleImpl; +} + +WorkerPool::WorkerThread::WorkerThread(WorkerPool* parent) + : Thread(strf("WorkerThread for WorkerPool '%s'", parent->m_name)), + parent(parent), + shouldStop(false), + waiting(false) { + start(); +} + +WorkerPool::WorkerThread::~WorkerThread() { + join(); +} + +void WorkerPool::WorkerThread::run() { + MutexLocker workLock(parent->m_workMutex); + while (true) { + if (shouldStop) + break; + + if (parent->m_pendingWork.empty()) { + waiting = true; + parent->m_workCondition.wait(parent->m_workMutex); + waiting = false; + } + + if (!parent->m_pendingWork.empty()) { + auto work = parent->m_pendingWork.takeFirst(); + workLock.unlock(); + work(); + workLock.lock(); + } + } +} + +void WorkerPool::queueWork(function<void()> work) { + MutexLocker workLock(m_workMutex); + m_pendingWork.append(move(work)); + m_workCondition.signal(); +} + +} diff --git a/source/core/StarWorkerPool.hpp b/source/core/StarWorkerPool.hpp new file mode 100644 index 0000000..8d87bcb --- /dev/null +++ b/source/core/StarWorkerPool.hpp @@ -0,0 +1,222 @@ +#ifndef STAR_WORKER_POOL_HPP +#define STAR_WORKER_POOL_HPP + +#include "StarThread.hpp" + +namespace Star { + +STAR_EXCEPTION(WorkerPoolException, StarException); + +STAR_CLASS(WorkerPool); + +// Shareable handle for a WorkerPool computation that does not produce any +// value. +class WorkerPoolHandle { +public: + // Returns true if the work is completed (either due to error or actual + // completion, will not re-throw) + bool done() const; + + // Waits up to given millis for the computation to finish. Returns true if + // the computation finished within the allotted time, false otherwise. If + // the computation is finished but it threw an exception, it will be + // re-thrown here. + bool wait(unsigned millis) const; + + // synonym for wait(0) + bool poll() const; + + // Wait until the computation finishes completely. If the computation threw + // an exception it will be re-thrown by this method. + void finish() const; + +private: + friend WorkerPool; + + struct Impl { + Impl(); + + Mutex mutex; + ConditionVariable condition; + atomic<bool> done; + std::exception_ptr exception; + }; + + WorkerPoolHandle(shared_ptr<Impl> impl); + + shared_ptr<Impl> m_impl; +}; + +// Shareable handle for a WorkerPool computation that produces a value. +template <typename ResultType> +class WorkerPoolPromise { +public: + // Returns true if the work is completed (either due to error or actual + // completion, will not re-throw) + bool done() const; + + // Waits for the given amount of time for the work to be completed. If the + // work is completed, returns true. If the producer function throws for any + // reason, this method will re-throw the exception. If millis is zero, does + // not wait at all simply polls to see if the computation is finished. + bool wait(unsigned millis) const; + + // synonym for wait(0) + bool poll() const; + + // Blocks until the work is done, and returns the result. May be called + // multiple times to access the result. If the computation threw + // an exception it will be re-thrown by this method. + ResultType& get(); + ResultType const& get() const; + +private: + friend WorkerPool; + + struct Impl { + Mutex mutex; + ConditionVariable condition; + Maybe<ResultType> result; + std::exception_ptr exception; + }; + + WorkerPoolPromise(shared_ptr<Impl> impl); + + shared_ptr<Impl> m_impl; +}; + +class WorkerPool { +public: + // Creates a stopped pool + WorkerPool(String name); + // Creates a started pool + WorkerPool(String name, unsigned threadCount); + ~WorkerPool(); + + WorkerPool(WorkerPool&&); + WorkerPool& operator=(WorkerPool&&); + + // Start the thread pool with the given thread count range, or if it is + // already started, reconfigure the thread counts. + void start(unsigned threadCount); + + // Stop the thread pool, not necessarily finishing any pending jobs (may + // leave pending jobs on the queue). + void stop(); + + // Try to finish any remaining jobs, then stop the thread pool. This method + // must not be called if the worker pool will continuously receive new work, + // as it may not ever complete if that is the case. The work queue must + // eventually become empty for this to properly return. + void finish(); + + // Add the given work to the pool and return a handle for the work. It not + // required that the caller of this method hold on to the worker handle, the + // work will be managed and completed regardless of the WorkerPoolHandle + // lifetime. + WorkerPoolHandle addWork(function<void()> work); + + // Like addWork, but the worker is expected to produce some result. The + // returned promise can be used to get this return value once the producer is + // complete. + template <typename ResultType> + WorkerPoolPromise<ResultType> addProducer(function<ResultType()> producer); + +private: + class WorkerThread : public Thread { + public: + // Starts automatically + WorkerThread(WorkerPool* parent); + ~WorkerThread(); + + void run() override; + + WorkerPool* parent; + atomic<bool> shouldStop; + atomic<bool> waiting; + }; + + void queueWork(function<void()> work); + + String m_name; + Mutex m_threadMutex; + List<unique_ptr<WorkerThread>> m_workerThreads; + + Mutex m_workMutex; + ConditionVariable m_workCondition; + Deque<function<void()>> m_pendingWork; +}; + +template <typename ResultType> +bool WorkerPoolPromise<ResultType>::done() const { + MutexLocker locker(m_impl->mutex); + return m_impl->result || m_impl->exception; +} + +template <typename ResultType> +bool WorkerPoolPromise<ResultType>::wait(unsigned millis) const { + MutexLocker locker(m_impl->mutex); + + if (!m_impl->result && !m_impl->exception && millis != 0) + m_impl->condition.wait(m_impl->mutex, millis); + + if (m_impl->exception) + std::rethrow_exception(m_impl->exception); + + if (m_impl->result) + return true; + + return false; +} + +template <typename ResultType> +bool WorkerPoolPromise<ResultType>::poll() const { + return wait(0); +} + +template <typename ResultType> +ResultType& WorkerPoolPromise<ResultType>::get() { + MutexLocker locker(m_impl->mutex); + + if (!m_impl->result && !m_impl->exception) + m_impl->condition.wait(m_impl->mutex); + + if (m_impl->exception) + std::rethrow_exception(m_impl->exception); + + return *m_impl->result; +} + +template <typename ResultType> +ResultType const& WorkerPoolPromise<ResultType>::get() const { + return const_cast<WorkerPoolPromise*>(this)->get(); +} + +template <typename ResultType> +WorkerPoolPromise<ResultType>::WorkerPoolPromise(shared_ptr<Impl> impl) + : m_impl(move(impl)) {} + +template <typename ResultType> +WorkerPoolPromise<ResultType> WorkerPool::addProducer(function<ResultType()> producer) { + // Construct a worker pool promise and wrap the producer to signal the + // promise when finished. + auto workerPoolPromiseImpl = make_shared<typename WorkerPoolPromise<ResultType>::Impl>(); + queueWork([workerPoolPromiseImpl, producer]() { + try { + auto result = producer(); + MutexLocker promiseLocker(workerPoolPromiseImpl->mutex); + workerPoolPromiseImpl->result = move(result); + workerPoolPromiseImpl->condition.broadcast(); + } catch (...) { + MutexLocker promiseLocker(workerPoolPromiseImpl->mutex); + workerPoolPromiseImpl->exception = std::current_exception(); + workerPoolPromiseImpl->condition.broadcast(); + } + }); + + return workerPoolPromiseImpl; +} + +} + +#endif diff --git a/source/core/StarXXHash.hpp b/source/core/StarXXHash.hpp new file mode 100644 index 0000000..28ef202 --- /dev/null +++ b/source/core/StarXXHash.hpp @@ -0,0 +1,142 @@ +#ifndef STAR_XXHASH_HPP +#define STAR_XXHASH_HPP + +#include "StarString.hpp" +#include "StarByteArray.hpp" + +#define XXH_STATIC_LINKING_ONLY +#include "xxhash.h" + +namespace Star { + +class XXHash32 { +public: + XXHash32(uint32_t seed = 0); + + void push(char const* data, size_t length); + uint32_t digest(); + +private: + XXH32_state_s state; +}; + +class XXHash64 { +public: + XXHash64(uint64_t seed = 0); + + void push(char const* data, size_t length); + uint64_t digest(); + +private: + XXH64_state_s state; +}; + +uint32_t xxHash32(char const* source, size_t length); +uint32_t xxHash32(ByteArray const& in); +uint32_t xxHash32(String const& in); + +uint64_t xxHash64(char const* source, size_t length); +uint64_t xxHash64(ByteArray const& in); +uint64_t xxHash64(String const& in); + +#define XXHASH32_PRIMITIVE(TYPE, CAST_TYPE) \ + inline void xxHash32Push(XXHash32& hash, TYPE const& v) { \ + CAST_TYPE cv = v; \ + cv = toLittleEndian(cv); \ + hash.push((char const*)(&cv), sizeof(cv)); \ + } + +#define XXHASH64_PRIMITIVE(TYPE, CAST_TYPE) \ + inline void xxHash64Push(XXHash64& hash, TYPE const& v) { \ + CAST_TYPE cv = v; \ + cv = toLittleEndian(cv); \ + hash.push((char const*)(&cv), sizeof(cv)); \ + } + +XXHASH32_PRIMITIVE(bool, bool); +XXHASH32_PRIMITIVE(int, int32_t); +XXHASH32_PRIMITIVE(long, int64_t); +XXHASH32_PRIMITIVE(long long, int64_t); +XXHASH32_PRIMITIVE(unsigned int, uint32_t); +XXHASH32_PRIMITIVE(unsigned long, uint64_t); +XXHASH32_PRIMITIVE(unsigned long long, uint64_t); +XXHASH32_PRIMITIVE(float, float); +XXHASH32_PRIMITIVE(double, double); + +XXHASH64_PRIMITIVE(bool, bool); +XXHASH64_PRIMITIVE(int, int32_t); +XXHASH64_PRIMITIVE(long, int64_t); +XXHASH64_PRIMITIVE(long long, int64_t); +XXHASH64_PRIMITIVE(unsigned int, uint32_t); +XXHASH64_PRIMITIVE(unsigned long, uint64_t); +XXHASH64_PRIMITIVE(unsigned long long, uint64_t); +XXHASH64_PRIMITIVE(float, float); +XXHASH64_PRIMITIVE(double, double); + +inline void xxHash32Push(XXHash32& hash, char const* str) { + hash.push(str, strlen(str)); +} + +inline void xxHash32Push(XXHash32& hash, String const& str) { + hash.push(str.utf8Ptr(), str.size()); +} + +inline void xxHash64Push(XXHash64& hash, char const* str) { + hash.push(str, strlen(str)); +} + +inline void xxHash64Push(XXHash64& hash, String const& str) { + hash.push(str.utf8Ptr(), str.size()); +} + +inline XXHash32::XXHash32(uint32_t seed) { + XXH32_reset(&state, seed); +} + +inline void XXHash32::push(char const* data, size_t length) { + XXH32_update(&state, data, length); +} + +inline uint32_t XXHash32::digest() { + return XXH32_digest(&state); +} + +inline XXHash64::XXHash64(uint64_t seed) { + XXH64_reset(&state, seed); +} + +inline void XXHash64::push(char const* data, size_t length) { + XXH64_update(&state, data, length); +} + +inline uint64_t XXHash64::digest() { + return XXH64_digest(&state); +} + +inline uint32_t xxHash32(char const* source, size_t length) { + return XXH32(source, length, 0); +} + +inline uint32_t xxHash32(ByteArray const& in) { + return xxHash32(in.ptr(), in.size()); +} + +inline uint32_t xxHash32(String const& in) { + return xxHash32(in.utf8Ptr(), in.utf8Size()); +} + +inline uint64_t xxHash64(char const* source, size_t length) { + return XXH64(source, length, 0); +} + +inline uint64_t xxHash64(ByteArray const& in) { + return xxHash64(in.ptr(), in.size()); +} + +inline uint64_t xxHash64(String const& in) { + return xxHash64(in.utf8Ptr(), in.utf8Size()); +} + +} + +#endif |