diff --git a/WORKSPACE b/WORKSPACE index 148512d1c..0ec3b8a1b 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -99,7 +99,7 @@ rocm_configure(name = "local_config_rocm") http_archive( name = "com_google_sentencepiece", - build_file = "@//patches:sentencepiece.BUILD", + build_file = "@//third_party:sentencepiece.BUILD", patch_args = ["-p1"], patches = ["@//patches:com_google_sentencepiece.diff"], sha256 = "8409b0126ebd62b256c685d5757150cf7fcb2b92a2f2b98efb3f38fc36719754", @@ -111,7 +111,7 @@ http_archive( http_archive( name = "darts_clone", - build_file = "@//patches:darts_clone.BUILD", + build_file = "@//third_party:darts_clone.BUILD", patch_args = ["-p0"], patches = ["//patches:darts_no_exceptions.diff"], sha256 = "c97f55d05c98da6fcaf7f9ecc6a6dc6bc5b18b8564465f77abff8879d446491c", @@ -161,6 +161,14 @@ http_archive( urls = ["https://github.com/MediaTek-NeuroPilot/tflite-neuron-delegate/archive/refs/heads/update_for_leroy.zip"], ) +http_archive( + name = "oleander_stemming_library", + build_file = "@//third_party:oleander_stemming_library.BUILD", + sha256 = "d4390e82590d67c73ac32629ddd4fc3ba0b6b293a2757612a2e76726c3752e0b", + strip_prefix = "OleanderStemmingLibrary-45eb3485f67b94d67bb883601ed65459975b3960", + urls = ["https://github.com/Blake-Madden/OleanderStemmingLibrary/archive/45eb3485f67b94d67bb883601ed65459975b3960.zip"], +) + new_git_repository( name = "org_mlperf_inference", build_file = "@//flutter/android/third_party:loadgen.BUILD", diff --git a/flutter/cpp/datasets/ifeval_utils/BUILD b/flutter/cpp/datasets/ifeval_utils/BUILD index c13cf74f2..70f3962ce 100644 --- a/flutter/cpp/datasets/ifeval_utils/BUILD +++ b/flutter/cpp/datasets/ifeval_utils/BUILD @@ -22,6 +22,7 @@ cc_library( name = "ifeval_utils", hdrs = [ "common.h", + "irregular-plurals.h", "json.h", "types.h", ], @@ -36,5 +37,6 @@ cc_library( }), deps = [ "@cld2", + "@oleander_stemming_library", ], ) diff --git a/flutter/cpp/datasets/ifeval_utils/common.h b/flutter/cpp/datasets/ifeval_utils/common.h index 629a55575..bba83ed4b 100644 --- a/flutter/cpp/datasets/ifeval_utils/common.h +++ b/flutter/cpp/datasets/ifeval_utils/common.h @@ -35,11 +35,14 @@ inline std::string tolower(std::string s) { return s; } -inline bool ends_with(const std::string& s, const std::string& suf) { - if (s.size() < suf.size()) return false; - std::string a = tolower(s.substr(s.size() - suf.size())); - std::string b = tolower(suf); - return a == b; +inline std::string to_lower_ascii(std::string s) { + for (char& c : s) + c = static_cast(std::tolower(static_cast(c))); + return s; +} + +inline bool is_word_char(unsigned char c) { + return std::isalnum(c) || c == '_'; } inline bool contains_string(const std::string& text, @@ -48,17 +51,25 @@ inline bool contains_string(const std::string& text, return h.find(n) != std::string::npos; } +inline bool ends_with(const std::string& s, const std::string& suf, + unsigned threshold) { + if (s.size() < suf.size()) return false; + std::string a = tolower(s.substr(s.size() - (suf.size() + threshold))); + std::string b = tolower(suf); + return threshold == 0 ? a == b : contains_string(a, b); +} + +inline bool starts_with(const std::string& s, const std::string& prf, + unsigned threshold) { + if (s.size() < prf.size()) return false; + std::string a = tolower(s.substr(0, prf.size() + threshold)); + std::string b = tolower(prf); + return threshold == 0 ? a == b : contains_string(a, b); +} + inline bool contains_word(const std::string& text, const std::string& word) { if (word.empty()) return false; - auto to_lower_ascii = [](std::string s) { - for (char& c : s) c = std::tolower(static_cast(c)); - return s; - }; - auto is_word_char = [](unsigned char c) { - return std::isalnum(c) || c == '_'; // match std::regex \b notion of "word" - }; - std::string t = to_lower_ascii(text); std::string w = to_lower_ascii(word); @@ -83,6 +94,39 @@ inline bool contains_none(const std::string& text, return true; } +inline size_t find_containing_word(const std::string& text, + const std::string& keyword, + std::string& containing_word, size_t pos) { + if (keyword.empty() || pos >= text.size()) return std::string::npos; + + std::string t = to_lower_ascii(text); + std::string k = to_lower_ascii(keyword); + + if ((pos = t.find(k, pos)) == std::string::npos) return std::string::npos; + + // Expand left to word boundary + size_t start = pos; + while (start > 0 && is_word_char(static_cast(t[start - 1]))) { + --start; + } + + // Expand right to word boundary + size_t end = pos + k.size(); + while (end < t.size() && is_word_char(static_cast(t[end]))) { + ++end; + } + + // Extract original (not lowercased) word + containing_word = text.substr(start, end - start); + return start; +} + +inline size_t find_containing_word(const std::string& text, + const std::string& keyword, + std::string& out_word) { + return find_containing_word(text, keyword, out_word, 0); +} + inline std::string remove_font_modifiers(const std::string& s) { std::string out; out.reserve(s.size()); @@ -115,14 +159,12 @@ inline std::string remove_font_modifiers(const std::string& s) { inline std::string remove_first_line(const std::string& s) { std::size_t pos = s.find('\n'); - return (pos == std::string::npos) ? std::string{} : s.substr(pos + 1); - // If there is no newline, removing the first line yields empty. + return (pos == std::string::npos) ? std::string(s) : s.substr(pos + 1); } inline std::string remove_last_line(const std::string& s) { std::size_t pos = s.rfind('\n'); - return (pos == std::string::npos) ? std::string{} : s.substr(0, pos); - // If there is no newline, removing the last line yields empty. + return (pos == std::string::npos) ? std::string(s) : s.substr(0, pos); } // Returns the 8 transformations as an array of strings. diff --git a/flutter/cpp/datasets/ifeval_utils/irregular-plurals.h b/flutter/cpp/datasets/ifeval_utils/irregular-plurals.h new file mode 100644 index 000000000..0b17d09e7 --- /dev/null +++ b/flutter/cpp/datasets/ifeval_utils/irregular-plurals.h @@ -0,0 +1,339 @@ +// generated from super-duper-clean-irregular-plurals.json + +#ifndef MLPERF_DATASETS_IFEVAL_UTILS_IRREGULAR_PLURALS_H_ +#define MLPERF_DATASETS_IFEVAL_UTILS_IRREGULAR_PLURALS_H_ + +#include +#include + +namespace mlperf { +namespace mobile { +namespace ifeval { + +const std::unordered_map pluralMap = { + {"abscissa", "abscissae"}, + {"addendum", "addenda"}, + {"agendum", "agenda"}, + {"alga", "algae"}, + {"alumna", "alumnae"}, + {"alumnus", "alumni"}, + {"alveolus", "alveoli"}, + {"analysis", "analyses"}, + {"antithesis", "antitheses"}, + {"aphelion", "aphelia"}, + {"axis", "axes"}, + {"bacillus", "bacilli"}, + {"bacterium", "bacteria"}, + {"baculum", "bacula"}, + {"basis", "bases"}, + {"businessman", "businessmen"}, + {"calf", "calves"}, + {"candelabrum", "candelabra"}, + {"chairman", "chairmen"}, + {"child", "children"}, + {"cloaca", "cloacae"}, + {"codex", "codices"}, + {"consortium", "consortia"}, + {"corpus", "corpora"}, + {"cortex", "cortices"}, + {"cranium", "crania"}, + {"crisis", "crises"}, + {"criterion", "criteria"}, + {"curriculum", "curricula"}, + {"cystoma", "cystomata"}, + {"datum", "data"}, + {"desideratum", "desiderata"}, + {"diagnosis", "diagnoses"}, + {"dictum", "dicta"}, + {"die", "dice"}, + {"djinni", "djinn"}, + {"dogma", "dogmata"}, + {"elf", "elves"}, + {"ellipsis", "ellipses"}, + {"emphasis", "emphases"}, + {"emporium", "emporia"}, + {"encomium", "encomia"}, + {"ephemeris", "ephemerides"}, + {"erratum", "errata"}, + {"extremum", "extrema"}, + {"fez", "fezzes"}, + {"fibula", "fibulae"}, + {"foot", "feet"}, + {"foramen", "foramina"}, + {"fungus", "fungi"}, + {"ganglion", "ganglia"}, + {"gentleman", "gentlemen"}, + {"genus", "genera"}, + {"glomerulus", "glomeruli"}, + {"goose", "geese"}, + {"goy", "goyim"}, + {"graffito", "graffiti"}, + {"gumma", "gummata"}, + {"half", "halves"}, + {"hamulus", "hamuli"}, + {"honorarium", "honoraria"}, + {"hoof", "hooves"}, + {"humerus", "humeri"}, + {"hyperbaton", "hyperbata"}, + {"hyperbola", "hyperbolae"}, + {"hypothesis", "hypotheses"}, + {"ilium", "ilia"}, + {"incubus", "incubi"}, + {"interregnum", "interregna"}, + {"interstitium", "interstitia"}, + {"knife", "knives"}, + {"larva", "larvae"}, + {"leaf", "leaves"}, + {"life", "lives"}, + {"loaf", "loaves"}, + {"loculus", "loculi"}, + {"locus", "loci"}, + {"looey", "looies"}, + {"louse", "lice"}, + {"lumen", "lumina"}, + {"lustrum", "lustra"}, + {"lymphoma", "lymphomata"}, + {"man", "men"}, + {"matrix", "matrices"}, + {"maximum", "maxima"}, + {"medium", "media"}, + {"memorandum", "memoranda"}, + {"meniscus", "menisci"}, + {"millennium", "millennia"}, + {"minimum", "minima"}, + {"minutia", "minutiae"}, + {"momentum", "momenta"}, + {"mouse", "mice"}, + {"murex", "murices"}, + {"mythos", "mythoi"}, + {"nemesis", "nemeses"}, + {"neurosis", "neuroses"}, + {"noumenon", "noumena"}, + {"nucleolus", "nucleoli"}, + {"nucleus", "nuclei"}, + {"oasis", "oases"}, + {"occiput", "occipita"}, + {"omphalos", "omphaloi"}, + {"optimum", "optima"}, + {"ovum", "ova"}, + {"ox", "oxen"}, + {"paralysis", "paralyses"}, + {"parenthesis", "parentheses"}, + {"passerby", "passersby"}, + {"perihelion", "perihelia"}, + {"person", "people"}, + {"phalanx", "phalanges"}, + {"phenomenon", "phenomena"}, + {"phylum", "phyla"}, + {"policeman", "policemen"}, + {"polyhedron", "polyhedra"}, + {"pontifex", "pontifices"}, + {"prognosis", "prognoses"}, + {"prolegomenon", "prolegomena"}, + {"quantum", "quanta"}, + {"quiz", "quizzes"}, + {"radius", "radii"}, + {"sarcophagus", "sarcophagi"}, + {"scarf", "scarves"}, + {"scrotum", "scrota"}, + {"self", "selves"}, + {"shelf", "shelves"}, + {"silex", "silices"}, + {"simulacrum", "simulacra"}, + {"spokesman", "spokesmen"}, + {"spectrum", "spectra"}, + {"speculum", "specula"}, + {"stimulus", "stimuli"}, + {"stratum", "strata"}, + {"succubus", "succubi"}, + {"syconium", "syconia"}, + {"synopsis", "synopses"}, + {"synthesis", "syntheses"}, + {"testis", "testes"}, + {"that", "those"}, + {"thesis", "theses"}, + {"thief", "thieves"}, + {"this", "these"}, + {"thrombus", "thrombi"}, + {"tooth", "teeth"}, + {"torus", "tori"}, + {"trapezium", "trapezia"}, + {"umbilicus", "umbilici"}, + {"velum", "vela"}, + {"vertebra", "vertebrae"}, + {"vertex", "vertices"}, + {"viscus", "viscera"}, + {"vita", "vitae"}, + {"vortex", "vortices"}, + {"wharf", "wharves"}, + {"wife", "wives"}, + {"wolf", "wolves"}, + {"woman", "women"}, +}; + +const std::unordered_map singularMap = { + {"abscissae", "abscissa"}, + {"addenda", "addendum"}, + {"agenda", "agendum"}, + {"algae", "alga"}, + {"alumnae", "alumna"}, + {"alumni", "alumnus"}, + {"alveoli", "alveolus"}, + {"analyses", "analysis"}, + {"antitheses", "antithesis"}, + {"aphelia", "aphelion"}, + {"axes", "axis"}, + {"bacilli", "bacillus"}, + {"bacteria", "bacterium"}, + {"bacula", "baculum"}, + {"bases", "basis"}, + {"businessmen", "businessman"}, + {"calves", "calf"}, + {"candelabra", "candelabrum"}, + {"chairmen", "chairman"}, + {"children", "child"}, + {"cloacae", "cloaca"}, + {"codices", "codex"}, + {"consortia", "consortium"}, + {"corpora", "corpus"}, + {"cortices", "cortex"}, + {"crania", "cranium"}, + {"crises", "crisis"}, + {"criteria", "criterion"}, + {"curricula", "curriculum"}, + {"cystomata", "cystoma"}, + {"data", "datum"}, + {"desiderata", "desideratum"}, + {"diagnoses", "diagnosis"}, + {"dicta", "dictum"}, + {"dice", "die"}, + {"djinn", "djinni"}, + {"dogmata", "dogma"}, + {"elves", "elf"}, + {"ellipses", "ellipsis"}, + {"emphases", "emphasis"}, + {"emporia", "emporium"}, + {"encomia", "encomium"}, + {"ephemerides", "ephemeris"}, + {"errata", "erratum"}, + {"extrema", "extremum"}, + {"fezzes", "fez"}, + {"fibulae", "fibula"}, + {"feet", "foot"}, + {"foramina", "foramen"}, + {"fungi", "fungus"}, + {"ganglia", "ganglion"}, + {"gentlemen", "gentleman"}, + {"genera", "genus"}, + {"glomeruli", "glomerulus"}, + {"geese", "goose"}, + {"goyim", "goy"}, + {"graffiti", "graffito"}, + {"gummata", "gumma"}, + {"halves", "half"}, + {"hamuli", "hamulus"}, + {"honoraria", "honorarium"}, + {"hooves", "hoof"}, + {"humeri", "humerus"}, + {"hyperbata", "hyperbaton"}, + {"hyperbolae", "hyperbola"}, + {"hypotheses", "hypothesis"}, + {"ilia", "ilium"}, + {"incubi", "incubus"}, + {"interregna", "interregnum"}, + {"interstitia", "interstitium"}, + {"knives", "knife"}, + {"larvae", "larva"}, + {"leaves", "leaf"}, + {"lives", "life"}, + {"loaves", "loaf"}, + {"loculi", "loculus"}, + {"loci", "locus"}, + {"looies", "looey"}, + {"lice", "louse"}, + {"lumina", "lumen"}, + {"lustra", "lustrum"}, + {"lymphomata", "lymphoma"}, + {"men", "man"}, + {"matrices", "matrix"}, + {"maxima", "maximum"}, + {"media", "medium"}, + {"memoranda", "memorandum"}, + {"menisci", "meniscus"}, + {"millennia", "millennium"}, + {"minima", "minimum"}, + {"minutiae", "minutia"}, + {"momenta", "momentum"}, + {"mice", "mouse"}, + {"murices", "murex"}, + {"mythoi", "mythos"}, + {"nemeses", "nemesis"}, + {"neuroses", "neurosis"}, + {"noumena", "noumenon"}, + {"nucleoli", "nucleolus"}, + {"nuclei", "nucleus"}, + {"oases", "oasis"}, + {"occipita", "occiput"}, + {"omphaloi", "omphalos"}, + {"optima", "optimum"}, + {"ova", "ovum"}, + {"oxen", "ox"}, + {"paralyses", "paralysis"}, + {"parentheses", "parenthesis"}, + {"passersby", "passerby"}, + {"perihelia", "perihelion"}, + {"people", "person"}, + {"phalanges", "phalanx"}, + {"phenomena", "phenomenon"}, + {"phyla", "phylum"}, + {"policemen", "policeman"}, + {"polyhedra", "polyhedron"}, + {"pontifices", "pontifex"}, + {"prognoses", "prognosis"}, + {"prolegomena", "prolegomenon"}, + {"quanta", "quantum"}, + {"quizzes", "quiz"}, + {"radii", "radius"}, + {"sarcophagi", "sarcophagus"}, + {"scarves", "scarf"}, + {"scrota", "scrotum"}, + {"selves", "self"}, + {"shelves", "shelf"}, + {"silices", "silex"}, + {"simulacra", "simulacrum"}, + {"spokesmen", "spokesman"}, + {"spectra", "spectrum"}, + {"specula", "speculum"}, + {"stimuli", "stimulus"}, + {"strata", "stratum"}, + {"succubi", "succubus"}, + {"syconia", "syconium"}, + {"synopses", "synopsis"}, + {"syntheses", "synthesis"}, + {"testes", "testis"}, + {"those", "that"}, + {"theses", "thesis"}, + {"thieves", "thief"}, + {"these", "this"}, + {"thrombi", "thrombus"}, + {"teeth", "tooth"}, + {"tori", "torus"}, + {"trapezia", "trapezium"}, + {"umbilici", "umbilicus"}, + {"vela", "velum"}, + {"vertebrae", "vertebra"}, + {"vertices", "vertex"}, + {"viscera", "viscus"}, + {"vitae", "vita"}, + {"vortices", "vortex"}, + {"wharves", "wharf"}, + {"wives", "wife"}, + {"wolves", "wolf"}, + {"women", "woman"}, +}; + +} // namespace ifeval +} // namespace mobile +} // namespace mlperf + +#endif // MLPERF_DATASETS_IFEVAL_UTILS_IRREGULAR_PLURALS_H_ diff --git a/flutter/cpp/datasets/ifeval_utils/types.h b/flutter/cpp/datasets/ifeval_utils/types.h index 5f7f020c7..4bc4cafba 100644 --- a/flutter/cpp/datasets/ifeval_utils/types.h +++ b/flutter/cpp/datasets/ifeval_utils/types.h @@ -3,13 +3,18 @@ #include #include +#include +#include #include #include #include +#include #include #include "compact_lang_det.h" +#include "english_stem.h" #include "flutter/cpp/datasets/ifeval_utils/common.h" +#include "flutter/cpp/datasets/ifeval_utils/irregular-plurals.h" #include "flutter/cpp/datasets/ifeval_utils/json.h" namespace mlperf { @@ -145,8 +150,7 @@ class RepeatPrompt : public Instruction { private: std::string prompt_; virtual bool verify_(const std::string& resp) const override { - // TODO replace with startswith? - return contains_string(resp, prompt_); + return starts_with(resp, prompt_, 3); } }; @@ -160,6 +164,11 @@ class TwoResponses : public Instruction { std::size_t count = 0; std::size_t pos = resp.find("******"); while (pos != std::string::npos) { + if (pos == 0 || + pos == resp.size() - 6) { // ignore indicators at the start and end + pos = resp.find("******", pos + 6); + continue; + } if (++count > 1) return false; // more than one occurrence pos = resp.find("******", pos + 6); // disallow overlapping matches } @@ -180,24 +189,14 @@ class NumberPlaceholders : public Instruction { std::size_t count = 0, pos = 0; while (pos < resp.length() && (int)count < n_) { // no need to keep looking if the requirement is - // already satisfied + // already satisfied std::size_t open = resp.find('[', pos); if (open == std::string::npos) break; std::size_t close = resp.find(']', open + 1); if (close == std::string::npos) break; - if (close > open + 1) { // non-empty inner - const std::string inner = resp.substr(open + 1, close - open - 1); - bool ok = true; - for (unsigned char ch : inner) { - if (std::isspace(ch) || !(std::isalnum(ch) || ch == '_')) { - ok = false; - break; - } - } - if (ok) ++count; - } - pos = close + 1; // continue after this closing bracket + if (close > open + 1) ++count; // non-empty inner + pos = close + 1; // continue after this closing bracket } return (int)count >= n_; } @@ -224,9 +223,18 @@ class ConstrainedResponse : public Instruction { constexpr InstructionGroup Group() override { return DETECTABLE_FORMAT; } private: + // TODO constexpr? + const std::string aYes = "My answer is yes."; + const std::string aNo = "My answer is no."; + const std::string aMaybe = "My answer is maybe."; + const unsigned sizeThreshold = 3; virtual bool verify_(const std::string& resp) const override { - return resp == "My answer is yes." || resp == "My answer is no." || - resp == "My answer is maybe."; + return (resp.find(aYes) != std::string::npos && + resp.size() <= sizeThreshold + aYes.size()) || + (resp.find(aNo) != std::string::npos && + resp.size() <= sizeThreshold + aNo.size()) || + (resp.find(aMaybe) != std::string::npos && + resp.size() <= sizeThreshold + aMaybe.size()); } }; @@ -239,6 +247,14 @@ class JsonFormat : public Instruction { virtual bool verify_(const std::string& resp) const override { std::string t = resp; if (t.empty()) return false; + if (t[0] == '`') { + size_t first = t.find('\n'); + size_t last = t.rfind('\n'); + + if (first != std::string::npos && last != std::string::npos && + last > first) + t = t.substr(first + 1, last - first - 1); + } crow::json::rvalue jv = crow::json::load(t); return jv.is_valid(); } @@ -259,25 +275,42 @@ class MultipleSections : public Instruction { if (!trim(p).empty()) ++c; return c; } + + static bool isnum(const std::string text, size_t pos) { + unsigned char c = text[pos]; + return (c >= '0' && c <= '9') || c == 'I' || c == 'V' || c == 'X'; + } + inline std::vector SplitByDelim(const std::string& s, const std::string& delim) const { if (delim.empty()) return {s}; std::vector parts; - size_t start = 0; + size_t start = s.find(delim); while (true) { - size_t pos = s.find(delim, start); + if (start == std::string::npos) break; + size_t pos = s.find(delim, start + delim.size()); if (pos == std::string::npos) { parts.push_back(s.substr(start)); break; } + if (!isnum(s, pos + delim.size() + + 1)) { // just a word, not "Section X", ignore and move + // on to the next one + start = pos; + continue; + } parts.push_back(s.substr(start, pos - start)); - start = pos + delim.size(); + start = pos; } return parts; } virtual bool verify_(const std::string& resp) const override { auto parts = SplitByDelim(resp, sep_); - return CountNonEmpty(parts) == n_; + int count = CountNonEmpty(parts); + if (resp.find("******") != std::string::npos) + count /= 2; // If 2 responses are given, divide by 2 so we get the result + // for each response + return count == n_; } }; @@ -301,7 +334,7 @@ class NumberBulletLists : public Instruction { size_t count = 0; for (const auto& line : SplitLines(resp)) { std::string t = trim(line); - if (t.rfind("* ", 0) == 0) { + if (t.rfind("* ", 0) == 0 || t.rfind("- ", 0) == 0) { ++count; continue; } @@ -406,62 +439,132 @@ class Frequency : public Instruction { int n_; std::string kw_; Relation rel_; + mutable stemming::english_stem<> stemmer; - static inline std::string RegexEscape(const std::string& s) { - auto is_meta = [](unsigned char ch) { - switch (ch) { - case '^': - case '$': - case '.': - case '|': - case '?': - case '*': - case '+': - case '(': - case ')': - case '[': - case ']': - case '{': - case '}': - case '\\': - return true; - default: - return false; - } - }; + std::wstring to_wstring_utf8(const std::string& s) const { + std::wstring_convert> conv; + return conv.from_bytes(s); + } - std::string out; - out.reserve(s.size() * 2); - for (unsigned char c : s) { - if (is_meta(c)) out.push_back('\\'); - out.push_back(static_cast(c)); - } - return out; + std::string to_string_utf8(const std::wstring& ws) const { + std::wstring_convert> conv; + return conv.to_bytes(ws); } - // Build a regex that matches the keyword with custom token boundaries. - // Left boundary is (^|[^A-Za-z0-9_]) to avoid lookbehind. - // Right boundary uses a lookahead (?=$|[^A-Za-z0-9_]). - static inline std::regex MakeKeywordRegex(const std::string& keyword) { - const std::string kw = RegexEscape(keyword); - const std::string pat = - "(^|[^A-Za-z0-9_])" // left boundary (consumes 1 char or start) - "(?:" + - kw + - ")" // keyword literal - "(?=$|[^A-Za-z0-9_])"; // right boundary (zero-width lookahead) - return std::regex(pat, std::regex::icase); + inline std::string getStem(const std::string& word) const { + std::wstring wwordStem(to_wstring_utf8(word)); + stemmer(wwordStem); + std::string wordStem(to_string_utf8(wwordStem)); + return wordStem; } - static inline std::size_t CountKeywordOccurrences( - const std::string& text, const std::string& keyword) { - const std::regex rx = MakeKeywordRegex(keyword); - std::size_t count = 0; - for (auto it = std::sregex_iterator(text.begin(), text.end(), rx), - end = std::sregex_iterator(); - it != end; ++it) { - ++count; + static inline std::string getIrregularPlural(const std::string& word) { + auto it = pluralMap.find(word); + return it != pluralMap.end() ? it->second : word; + } + + static inline std::string getIrregularSingular(const std::string& word) { + auto it = singularMap.find(word); + return it != singularMap.end() ? it->second : word; + } + + // FIXME this potentially doesn't count "try" if the keyword is "trying", + // solution involves stemming the entire text + inline std::size_t CountKeywordOccurrences(const std::string& text, + const std::string& keyword) const { + size_t count{0}; + bool hasStem{false}, stemSubstring{false}, hasPlural{false}, + hasSingular{false}; + + std::string keyword_base = tolower(keyword); + std::string keyword_stem = getStem(keyword_base); + std::string keyword_plural = getIrregularPlural(keyword_base); + std::string keyword_singular; + hasStem = keyword_stem != keyword_base; + stemSubstring = keyword_base.find(keyword_stem) != std::string::npos; + // if the irregular plural can be stemmed to the keyword or vice versa, it + // should be handled by the stemming logic + hasPlural = + keyword_plural != keyword && getStem(keyword_plural) != keyword_stem; + if (!hasPlural) { + keyword_singular = getIrregularSingular(keyword_base); + hasSingular = keyword_singular != keyword_base && + getStem(keyword_singular) != keyword_stem; + } + std::string search_keyword = stemSubstring ? keyword_stem : keyword_base; + + size_t pos = 0; + std::string found; + // count keywords by matching the smallest possible substring of the + // keyword, expanding it, and comparing to possible variants. + while ((pos = find_containing_word(text, search_keyword, found, pos)) != + std::string::npos) { + bool match = false; + found = tolower(found); + // Exact match, Hooray! + if (found == keyword_base) match = true; + std::string foundStem = getStem(found); + // stem match to original keyword (looking for "word", found "words" or + // "wording") + if (!match && foundStem == keyword_base) match = true; + if (!match && hasStem && stemSubstring) { + // match to stemmed keyword (original keyword is "words", found "word") + if (found == keyword_stem) match = true; + // stem match to stemmed keyword (original keyword is "words", found + // "wording") + else if (foundStem != found && foundStem == keyword_stem) + match = true; + } + + if (match) count++; + pos += found.size(); + } + // the stem's lettering differs from the original (words that end with 'y') + if (hasStem && !stemSubstring) { + pos = 0; + while ((pos = find_containing_word(text, keyword_stem, found, pos)) != + std::string::npos) { + found = tolower(found); + // stem match to stemmed keyword (original keyword is "try" (stemmed to + // "tri"), found "tries") since this loop only runs if stem differs from + // the keyword, we can safely assume no overlap occurs with the first + // loop. + if (getStem(found) == keyword_stem) count++; + pos += found.size(); + } + } + // count instances of irregular plurals not caught by the first loop + // this assumes that the plural doesn't stem to the original word (plural + // "children" isn't irregular since it stems to kw "child") + if (hasPlural) { + pos = 0; + while ((pos = find_containing_word(text, keyword_plural, found, pos)) != + std::string::npos) { + found = tolower(found); + // match to pluralized keyword (original keyword is "mouse", found + // "mice") + if (found == keyword_plural) count++; + pos += found.size(); + + // plural match to pluralized keyword is the same as an exact match. + } } + // count instances of irregular singulars not caught by the first loop + // this assumes that the keyword doesn't stem to the singular (kw "children" + // isn't irregular since it stems to singular "child") + if (hasSingular) { + pos = 0; + while ((pos = find_containing_word(text, keyword_singular, found, pos)) != + std::string::npos) { + found = tolower(found); + // match to singular keyword (original keyword is "mice", found "mouse") + if (found == keyword_singular) count++; + pos += found.size(); + + // singular match to singularized keyword is the same as an exact match. + } + } + return count; } @@ -537,8 +640,12 @@ class NthParagraphFirstWord : public Instruction { static std::string FirstWord(const std::string& s) { std::istringstream is(s); std::string w; + std::string fw; is >> w; - return tolower(w); + w = tolower(w); + for (char c : w) + if (std::isalpha(c) && !std::isspace(c)) fw.push_back(c); + return fw; } static inline std::vector SplitParagraphs(const std::string& s) { @@ -576,10 +683,12 @@ class NumberParagraphs : public Instruction { private: unsigned n_; + static constexpr unsigned threshold = + 5; // to allow 5 characters at the very start or end of the response virtual bool verify_(const std::string& resp) const override { std::size_t count = 0, pos = 0; - while ((pos = resp.find("***\n", pos)) != std::string::npos) { - ++count; + while ((pos = resp.find("***", pos)) != std::string::npos) { + if (pos >= threshold && pos <= resp.size() - (3 + threshold)) ++count; pos += 4; // advance by 3 for non-overlapping matches } return count == n_ - 1; // since *** is a saparator, the actual count is 1 @@ -596,10 +705,148 @@ class NumberSentences : public Instruction { private: int n_; Relation rel_; + + inline std::string word_before_dot(const std::string& s, size_t i) const { + size_t start = i; + while (start > 0 && std::isalpha((unsigned char)s[start - 1])) start--; + return s.substr(start, i - start); + } + + inline std::string word_after_dot(const std::string& s, size_t i) const { + size_t end = i + 1; + while (end < s.size() && std::isalpha((unsigned char)s[end])) end++; + return s.substr(i + 1, end - (i + 1)); + } + + inline bool is_letter(char c) const { return std::isalpha(c); } + + inline bool is_digit(char c) const { return c >= '0' && c <= '9'; } + + inline bool is_mark(char c) const { return c == '.' || c == '!' || c == '?'; } + + bool is_initialism(const std::string& s, size_t i) const { + size_t j = i; + unsigned count = 0; + + while (j > 0 && std::isupper((unsigned char)s[j - 1])) { + if (j + 1 < s.size() && s[j] == '.') { + count++; + j -= 2; + } else { + break; + } + } + + // check if followed by another X. for first '.' + if (count == 1) { + if (i + 2 < s.size() && std::isupper((unsigned char)s[i + 1]) && + s[i + 2] == '.') { + count = 2; + } + } + + return count >= 2; + } + + bool is_latin_abbrev(const std::string& s, size_t i) const { + if (i < 3) return false; + return std::islower((unsigned char)s[i - 3]) && s[i - 2] == '.' && + std::islower((unsigned char)s[i - 1]) && s[i] == '.'; + } + + bool is_title_abbrev(const std::string& s, size_t i) const { + static const std::unordered_set titles = { + "Mr", "Mrs", "Ms", "Dr", "Prof", "Sr", "Jr"}; + + std::string word = word_before_dot(s, i); + return !word.empty() && titles.count(word) != 0; + } + + bool is_enumeration_prefix(const std::string& s, size_t i) const { + if (i == 0) return false; + + // Must be followed by space or newline + if (i + 1 >= s.size() || (s[i + 1] != ' ' && s[i + 1] != '\n')) + return false; + + size_t start = i - 1; + + // ---- Numeric enumeration: 1. / 10. ---- + if (is_digit(s[start])) { + while (start > 0 && is_digit(s[start - 1])) start--; + } + + // TODO roman numerals maybe? + // ---- Letter enumeration: a. / A. ---- + else if (is_letter(s[start]) && start > 0 && is_letter(s[start - 1])) + return false; + + // General check + if (start == 0) return true; + + char prev = s[start - 1]; + if (prev == ' ' || prev == '\n' || is_mark(prev)) return true; + + return false; + } + + bool is_domain_suffix(const std::string& s, size_t i) const { + static const std::unordered_set tlds = { + "com", "net", "org", "io", "gov", "edu", "me"}; + + if (i + 1 >= s.size()) return false; + + std::string suffix = word_after_dot(s, i); + return tlds.count(suffix) != 0; + } + + bool is_decimal_point(const std::string& s, size_t i) const { + // digit '.' digit + if (i == 0 || i + 1 >= s.size()) return false; + return is_digit(s[i - 1]) && is_digit(s[i + 1]); + } + + bool is_abbreviation(const std::string& s, size_t i) const { + return is_initialism(s, i) || is_latin_abbrev(s, i) || + is_title_abbrev(s, i); + } + + bool abbreviation_blocks_sentence(const std::string& s, size_t i) const { + if (!is_abbreviation(s, i)) return false; + + // skip spaces + size_t j = i + 1; + while (j < s.size() && s[j] == ' ') j++; + + // If next token is lowercase, it's mid-sentence + if (j < s.size() && std::islower((unsigned char)s[j])) return true; + + return false; + } + + bool ends_sentence(const std::string& s, size_t i) const { + char c = s[i]; + + if (!is_mark(c)) return false; + + // collapse runs ?!... + if (i + 1 < s.size() && is_mark(s[i + 1])) return false; + + if (c == '.') { + if (is_decimal_point(s, i)) return false; + if (is_enumeration_prefix(s, i)) return false; + if (abbreviation_blocks_sentence(s, i)) return false; + if (is_domain_suffix(s, i)) return false; + } + + return true; + } + virtual bool verify_(const std::string& resp) const override { size_t count = 0; - for (unsigned char c : resp) { - if (c == '.' || c == '!' || c == '?') ++count; + + for (size_t i = 0; i < resp.size(); i++) { + if (ends_sentence(resp, i)) count++; } return compare(count, (size_t)n_, rel_); } @@ -653,7 +900,7 @@ class EndChecker : public Instruction { private: std::string end_; virtual bool verify_(const std::string& resp) const override { - return ends_with(resp, end_); + return ends_with(resp, end_, 3); } }; @@ -664,8 +911,7 @@ class Quotation : public Instruction { private: virtual bool verify_(const std::string& resp) const override { - if (resp.size() < 2) return false; - return resp.front() == '"' && resp.back() == '"'; + return resp.size() >= 2 && resp.front() == '"' && resp.back() == '"'; } }; diff --git a/patches/darts_clone.BUILD b/third_party/darts_clone.BUILD similarity index 100% rename from patches/darts_clone.BUILD rename to third_party/darts_clone.BUILD diff --git a/third_party/oleander_stemming_library.BUILD b/third_party/oleander_stemming_library.BUILD new file mode 100644 index 000000000..8cf99c679 --- /dev/null +++ b/third_party/oleander_stemming_library.BUILD @@ -0,0 +1,15 @@ +licenses(["notice"]) + +exports_files(["LICENSE"]) + +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "oleander_stemming_library", + hdrs = [ + "src/*.h", + ], + includes = [ + "src" + ] +) diff --git a/patches/sentencepiece.BUILD b/third_party/sentencepiece.BUILD similarity index 100% rename from patches/sentencepiece.BUILD rename to third_party/sentencepiece.BUILD