diff --git a/.gitignore b/.gitignore index e2e90ab..1b3a8f1 100644 --- a/.gitignore +++ b/.gitignore @@ -388,10 +388,6 @@ tags ### VisualStudioCode ### .vscode/* -!.vscode/settings.json -!.vscode/tasks.json -!.vscode/launch.json -!.vscode/extensions.json .history ### VisualStudio ### diff --git a/.vscode/settings.json b/.vscode/settings.json deleted file mode 100644 index de6bc22..0000000 --- a/.vscode/settings.json +++ /dev/null @@ -1,171 +0,0 @@ -{ - "go.testEnvVars": { - "CC": "gcc", - "CXX": "g++" - }, - "go.toolsEnvVars": { - "CC": "gcc", - "CXX": "g++" - }, - "go.diagnostic.vulncheck": "Imports", - "go.disableConcurrentTests": false, - "go.coverOnSave": true, - "go.coverOnSingleTest": true, - "go.coverOnSingleTestFile": true, - "go.inlayHints.assignVariableTypes": true, - "go.inlayHints.compositeLiteralFields": true, - "go.inlayHints.constantValues": true, - "go.inlayHints.compositeLiteralTypes": false, - "go.inlayHints.parameterNames": true, - "go.useLanguageServer": true, - "go.inferGopath": true, - "go.vetOnSave": "workspace", - "go.coverMode": "default", - "go.addTags": { - "tags": "json,bson,yaml", - "promptForTags": true - }, - "go.coverageDecorator": { - "type": "gutter", - "coveredHighlightColor": "rgba(64,128,128,0.5)", - "uncoveredHighlightColor": "rgba(128,64,64,0.25)", - "coveredGutterStyle": "blockgreen", - "uncoveredGutterStyle": "blockred" - }, - "go.formatTool": "gofumpt", - "go.coverShowCounts": true, - "go.enableCodeLens": { - "runtest": true - }, - "gopls": { - "ui.codelenses": { - "gc_details": true, - "upgrade_dependency": true, - "generate": true, - "regenerate_cgo": true, - "test": true, - "tidy": true, - "run_govulncheck": true - }, - "ui.diagnostic.analyses": { - "asmdecl": true, - "nilness": true, - "assign": true, - "atomic": true, - "atomicalign": true, - "bools": true, - "buildtags": true, - "cgocall": true, - "composites": true, - "copylocks": true, - "buildtag": true, - "errorsas": true, - "fieldalignment": true, - "deepequalerrors": true, - "fillreturns": true, - "fillstruct": true, - "nilfunc": true, - "ifaceassert": true, - "httpresponse": true, - "infertypeargs": true, - "lostcancel": true, - "loopclosure": true, - "printf": true, - "simplifycompositelit": true, - "stdmethods": true, - "shadow": true, - "structtag": true, - "timeformat": true, - "unmarshal": true, - "unreachable": true, - "unusedwrite": true, - "unusedvariable": true, - "unsafeptr": true, - "unusedresult": true, - "embed": true, - "nonewvars": true, - "noresultvalues": true, - "shift": true, - "simplifyrange": true, - "simplifyslice": true, - "sortslice": true, - "stringintconv": true, - "stubmethods": true, - "testinggoroutine": true, - "tests": true, - "undeclaredname": true, - "unusedparams": false - }, - "ui.semanticTokens": true, - "ui.completion.experimentalPostfixCompletions": true, - "ui.completion.usePlaceholders": false, - "ui.completion.matcher": "Fuzzy", - "ui.diagnostic.staticcheck": true, - "ui.diagnostic.annotations": { - "bounds": true, - "escape": true, - "inline": true, - "nil": true - }, - "ui.navigation.importShortcut": "Link", - "ui.noSemanticNumber": true, - "ui.navigation.symbolStyle": "Full", - "ui.noSemanticString": true, - "ui.documentation.linksInHover": false, - "ui.navigation.symbolMatcher": "FastFuzzy", - "ui.documentation.hoverKind": "FullDocumentation", - "ui.documentation.linkTarget": "pkg.go.dev", - "build.experimentalPackageCacheKey": true, - "build.memoryMode": "Normal", - "expandWorkspaceToModule": true, - "gofumpt": true - }, - "go.lintTool": "golangci-lint", - "go.toolsManagement.autoUpdate": true, - "go.coverageOptions": "showCoveredCodeOnly", - "go.survey.prompt": false, - "go.editorContextMenuCommands": { - "removeTags": true, - "fillStruct": true, - "testFile": true, - "testPackage": true, - "generateTestForFile": true, - "generateTestForPackage": true, - "benchmarkAtCursor": true - }, - "go.inlayHints.functionTypeParameters": true, - "go.inlayHints.rangeVariableTypes": true, - "go.installDependenciesWhenBuilding": true, - "go.logging.level": "info", - "go.testExplorer.showOutput": true, - "go.testExplorer.packageDisplayMode": "nested", - "go.testExplorer.showDynamicSubtestsInEditor": true, - "go.terminal.activateEnvironment": true, - "go.testTimeout": "30s", - "go.lintOnSave": "workspace", - "editor.fontLigatures": true, - "editor.formatOnPaste": true, - "editor.formatOnSave": true, - "editor.formatOnType": true, - "editor.cursorBlinking": "smooth", - "git.alwaysSignOff": true, - "git.autofetch": true, - "git.ignoreLimitWarning": true, - "files.eol": "\n", - "files.trimTrailingWhitespace": true, - "explorer.incrementalNaming": "smart", - "explorer.sortOrder": "type", - "files.exclude": { - "**/.idea/": true, - "**/.nuke": true, - "**/.vs": true, - "**/obj": true, - "**/TestResults": true - }, - "debug.autoExpandLazyVariables": true, - "debug.console.closeOnEnd": true, - "debug.console.acceptSuggestionOnEnter": "on", - "debug.allowBreakpointsEverywhere": true, - "debug.console.historySuggestions": true, - "debug.console.collapseIdenticalLines": true, -} diff --git a/.vscode/tasks.json b/.vscode/tasks.json deleted file mode 100644 index 1cdc723..0000000 --- a/.vscode/tasks.json +++ /dev/null @@ -1,76 +0,0 @@ -{ - "version": "2.0.0", - "tasks": [ - { - "label": "go-fastText: Race Test", - "type": "shell", - "command": "go", - "args": [ - "test", - "-race", - "-v", - "-covermode=atomic", - "-timeout", - "1m", - "./..." - ], - "group": "test" - }, - { - "label": "go-fastText: Test", - "type": "shell", - "command": "go", - "args": [ - "test", - "-v", - "-covermode=atomic", - "-timeout", - "30s", - "./..." - ], - "group": "test" - }, - { - "label": "go-fastText: Format", - "type": "shell", - "command": "gofumpt", - "args": [ - "-l", - "-w", - "." - ], - "group": "none", - "problemMatcher": [] - }, - { - "label": "go-fastText: Lint", - "type": "shell", - "command": "golangci-lint", - "args": [ - "run" - ], - "group": "none" - }, - { - "label": "go-fastText: GoMod Tidy", - "type": "shell", - "command": "go", - "args": [ - "mod", - "tidy" - ], - "group": "none", - "problemMatcher": [] - }, - { - "label": "go-fastText: Security Check", - "type": "shell", - "command": "gosec", - "args": [ - "./..." - ], - "group": "none", - "problemMatcher": [] - } - ] -} \ No newline at end of file diff --git a/fasttext.go b/fasttext.go index 28d840b..b5d11f9 100644 --- a/fasttext.go +++ b/fasttext.go @@ -136,7 +136,30 @@ func (handle Model) Wordvec(word string) []float32 { }, ) - defer C.FastText_FreeFloatVector(r) + defer C.FastText_FreeFloatVector(r) + + vectors := make([]float32, r.size) + ptr := (*float32)(unsafe.Pointer(r.data)) + copy(vectors, unsafe.Slice(ptr, r.size)) + return vectors +} + +func (handle Model) Sentencevec(sentence string) []float32 { + var pinner runtime.Pinner + defer pinner.Unpin() + + strData := cStr(sentence) + pinner.Pin(strData) + + r := C.FastText_Sentencevec( + handle.p, + C.FastText_String_t{ + data: strData, + size: C.size_t(len(sentence)), + }, + ) + + defer C.FastText_FreeFloatVector(r) vectors := make([]float32, r.size) ptr := (*float32)(unsafe.Pointer(r.data)) diff --git a/fasttext/include/fasttext.h b/fasttext/include/fasttext.h index acc6efc..e289d9f 100644 --- a/fasttext/include/fasttext.h +++ b/fasttext/include/fasttext.h @@ -33,8 +33,8 @@ namespace fasttext { -class FastText -{ + class FastText + { public: using TrainCallback = std::function; @@ -91,8 +91,8 @@ class FastText inline void getInputVector(Vector &vec, int32_t ind) { - vec.zero(); - addInputVector(vec, ind); + vec.zero(); + addInputVector(vec, ind); } const Args getArgs() const; @@ -115,6 +115,7 @@ class FastText void loadModel(const std::string &filename); + Vector getSentenceVector(std::istream &in) const; void getSentenceVector(std::istream &in, Vector &vec); void quantize(const Args &qargs, const TrainCallback &callback = {}); @@ -147,10 +148,10 @@ class FastText class AbortError : public std::runtime_error { - public: - AbortError() : std::runtime_error("Aborted.") - { - } + public: + AbortError() : std::runtime_error("Aborted.") + { + } }; -}; + }; } // namespace fasttext diff --git a/fasttext/src/fasttext.cc b/fasttext/src/fasttext.cc index 1c94662..8d83ffc 100644 --- a/fasttext/src/fasttext.cc +++ b/fasttext/src/fasttext.cc @@ -24,939 +24,986 @@ namespace fasttext { -constexpr int32_t FASTTEXT_VERSION = 12; /* Version 1b */ -constexpr int32_t FASTTEXT_FILEFORMAT_MAGIC_INT32 = 793712314; + constexpr int32_t FASTTEXT_VERSION = 12; /* Version 1b */ + constexpr int32_t FASTTEXT_FILEFORMAT_MAGIC_INT32 = 793712314; -bool comparePairs(const std::pair &l, const std::pair &r); + bool comparePairs(const std::pair &l, const std::pair &r); -std::shared_ptr FastText::createLoss(std::shared_ptr &output) -{ - loss_name lossName = args_->loss; - switch (lossName) - { - case loss_name::hs: - return std::make_shared(output, getTargetCounts()); - case loss_name::ns: - return std::make_shared(output, args_->neg, getTargetCounts()); - case loss_name::softmax: - return std::make_shared(output); - case loss_name::ova: - return std::make_shared(output); - default: - throw std::runtime_error("Unknown loss"); - } -} - -FastText::FastText() : quant_(false), wordVectors_(nullptr), trainException_(nullptr) -{ -} - -void FastText::addInputVector(Vector &vec, int32_t ind) const -{ - vec.addRow(*input_, ind); -} - -std::shared_ptr FastText::getDictionary() const -{ - return dict_; -} - -const Args FastText::getArgs() const -{ - return *args_.get(); -} - -std::shared_ptr FastText::getInputMatrix() const -{ - if (quant_) + std::shared_ptr FastText::createLoss(std::shared_ptr &output) { - throw std::runtime_error("Can't export quantized matrix"); + loss_name lossName = args_->loss; + switch (lossName) + { + case loss_name::hs: + return std::make_shared(output, getTargetCounts()); + case loss_name::ns: + return std::make_shared(output, args_->neg, getTargetCounts()); + case loss_name::softmax: + return std::make_shared(output); + case loss_name::ova: + return std::make_shared(output); + default: + throw std::runtime_error("Unknown loss"); + } } - assert(input_.get()); - return std::dynamic_pointer_cast(input_); -} - -void FastText::setMatrices(const std::shared_ptr &inputMatrix, - const std::shared_ptr &outputMatrix) -{ - assert(input_->size(1) == output_->size(1)); - input_ = std::dynamic_pointer_cast(inputMatrix); - output_ = std::dynamic_pointer_cast(outputMatrix); - wordVectors_.reset(); - args_->dim = input_->size(1); - - buildModel(); -} - -std::shared_ptr FastText::getOutputMatrix() const -{ - if (quant_ && args_->qout) + FastText::FastText() : quant_(false), wordVectors_(nullptr), trainException_(nullptr) { - throw std::runtime_error("Can't export quantized matrix"); } - assert(output_.get()); - return std::dynamic_pointer_cast(output_); -} -int32_t FastText::getWordId(const std::string &word) const -{ - return dict_->getId(word); -} - -int32_t FastText::getSubwordId(const std::string &subword) const -{ - int32_t h = dict_->hash(subword) % args_->bucket; - return dict_->nwords() + h; -} - -int32_t FastText::getLabelId(const std::string &label) const -{ - int32_t labelId = dict_->getId(label); - if (labelId != -1) + void FastText::addInputVector(Vector &vec, int32_t ind) const { - labelId -= dict_->nwords(); + vec.addRow(*input_, ind); } - return labelId; -} -Vector FastText::getWordVector(const std::string_view word) const -{ - const std::vector &ngrams = dict_->getSubwords(word); - Vector vec(args_->dim); - vec.zero(); + std::shared_ptr FastText::getDictionary() const + { + return dict_; + } - for (int i = 0; i < ngrams.size(); i++) + const Args FastText::getArgs() const { - addInputVector(vec, ngrams[i]); + return *args_.get(); } - if (ngrams.size() > 0) + std::shared_ptr FastText::getInputMatrix() const { - vec.mul(1.0 / ngrams.size()); + if (quant_) + { + throw std::runtime_error("Can't export quantized matrix"); + } + assert(input_.get()); + return std::dynamic_pointer_cast(input_); } - return std::move(vec); -} + void FastText::setMatrices(const std::shared_ptr &inputMatrix, + const std::shared_ptr &outputMatrix) + { + assert(input_->size(1) == output_->size(1)); -void FastText::getWordVector(Vector &vec, std::string_view word) const -{ - const std::vector &ngrams = dict_->getSubwords(word); - vec.zero(); + input_ = std::dynamic_pointer_cast(inputMatrix); + output_ = std::dynamic_pointer_cast(outputMatrix); + wordVectors_.reset(); + args_->dim = input_->size(1); - for (int i = 0; i < ngrams.size(); i++) - { - addInputVector(vec, ngrams[i]); + buildModel(); } - if (ngrams.size() > 0) + std::shared_ptr FastText::getOutputMatrix() const { - vec.mul(1.0 / ngrams.size()); + if (quant_ && args_->qout) + { + throw std::runtime_error("Can't export quantized matrix"); + } + assert(output_.get()); + return std::dynamic_pointer_cast(output_); } -} - -void FastText::getSubwordVector(Vector &vec, const std::string &subword) const -{ - vec.zero(); - int32_t h = dict_->hash(subword) % args_->bucket; - h = h + dict_->nwords(); - addInputVector(vec, h); -} -void FastText::saveVectors(const std::string &filename) -{ - if (!input_ || !output_) + int32_t FastText::getWordId(const std::string &word) const { - throw std::runtime_error("Model never trained"); + return dict_->getId(word); } - std::ofstream ofs(filename); - if (!ofs.is_open()) + + int32_t FastText::getSubwordId(const std::string &subword) const { - throw std::invalid_argument(filename + " cannot be opened for saving vectors!"); + int32_t h = dict_->hash(subword) % args_->bucket; + return dict_->nwords() + h; } - ofs << dict_->nwords() << " " << args_->dim << std::endl; - Vector vec(args_->dim); - for (int32_t i = 0; i < dict_->nwords(); i++) + + int32_t FastText::getLabelId(const std::string &label) const { - std::string word = dict_->getWord(i); - getWordVector(vec, std::string_view(word)); - ofs << word << " " << vec << std::endl; + int32_t labelId = dict_->getId(label); + if (labelId != -1) + { + labelId -= dict_->nwords(); + } + return labelId; } - ofs.close(); -} -void FastText::saveOutput(const std::string &filename) -{ - std::ofstream ofs(filename); - if (!ofs.is_open()) + Vector FastText::getWordVector(const std::string_view word) const { - throw std::invalid_argument(filename + " cannot be opened for saving vectors!"); + const std::vector &ngrams = dict_->getSubwords(word); + Vector vec(args_->dim); + vec.zero(); + + for (int i = 0; i < ngrams.size(); i++) + { + addInputVector(vec, ngrams[i]); + } + + if (ngrams.size() > 0) + { + vec.mul(1.0 / ngrams.size()); + } + + return std::move(vec); } - if (quant_) + + void FastText::getWordVector(Vector &vec, std::string_view word) const { - throw std::invalid_argument("Option -saveOutput is not supported for quantized models."); + const std::vector &ngrams = dict_->getSubwords(word); + vec.zero(); + + for (int i = 0; i < ngrams.size(); i++) + { + addInputVector(vec, ngrams[i]); + } + + if (ngrams.size() > 0) + { + vec.mul(1.0 / ngrams.size()); + } } - int32_t n = (args_->model == model_name::sup) ? dict_->nlabels() : dict_->nwords(); - ofs << n << " " << args_->dim << std::endl; - Vector vec(args_->dim); - for (int32_t i = 0; i < n; i++) + + void FastText::getSubwordVector(Vector &vec, const std::string &subword) const { - std::string word = (args_->model == model_name::sup) ? dict_->getLabel(i) : dict_->getWord(i); vec.zero(); - vec.addRow(*output_, i); - ofs << word << " " << vec << std::endl; + int32_t h = dict_->hash(subword) % args_->bucket; + h = h + dict_->nwords(); + addInputVector(vec, h); } - ofs.close(); -} -bool FastText::checkModel(std::istream &in) -{ - int32_t magic; - in.read((char *)&(magic), sizeof(int32_t)); - if (magic != FASTTEXT_FILEFORMAT_MAGIC_INT32) + void FastText::saveVectors(const std::string &filename) { - return false; + if (!input_ || !output_) + { + throw std::runtime_error("Model never trained"); + } + std::ofstream ofs(filename); + if (!ofs.is_open()) + { + throw std::invalid_argument(filename + " cannot be opened for saving vectors!"); + } + ofs << dict_->nwords() << " " << args_->dim << std::endl; + Vector vec(args_->dim); + for (int32_t i = 0; i < dict_->nwords(); i++) + { + std::string word = dict_->getWord(i); + getWordVector(vec, std::string_view(word)); + ofs << word << " " << vec << std::endl; + } + ofs.close(); } - in.read((char *)&(version), sizeof(int32_t)); - if (version > FASTTEXT_VERSION) + + void FastText::saveOutput(const std::string &filename) { - return false; + std::ofstream ofs(filename); + if (!ofs.is_open()) + { + throw std::invalid_argument(filename + " cannot be opened for saving vectors!"); + } + if (quant_) + { + throw std::invalid_argument("Option -saveOutput is not supported for quantized models."); + } + int32_t n = (args_->model == model_name::sup) ? dict_->nlabels() : dict_->nwords(); + ofs << n << " " << args_->dim << std::endl; + Vector vec(args_->dim); + for (int32_t i = 0; i < n; i++) + { + std::string word = (args_->model == model_name::sup) ? dict_->getLabel(i) : dict_->getWord(i); + vec.zero(); + vec.addRow(*output_, i); + ofs << word << " " << vec << std::endl; + } + ofs.close(); } - return true; -} -void FastText::signModel(std::ostream &out) -{ - const int32_t magic = FASTTEXT_FILEFORMAT_MAGIC_INT32; - const int32_t version = FASTTEXT_VERSION; - out.write((char *)&(magic), sizeof(int32_t)); - out.write((char *)&(version), sizeof(int32_t)); -} - -void FastText::saveModel(const std::string &filename) -{ - std::ofstream ofs(filename, std::ofstream::binary); - if (!ofs.is_open()) + bool FastText::checkModel(std::istream &in) { - throw std::invalid_argument(filename + " cannot be opened for saving!"); + int32_t magic; + in.read((char *)&(magic), sizeof(int32_t)); + if (magic != FASTTEXT_FILEFORMAT_MAGIC_INT32) + { + return false; + } + in.read((char *)&(version), sizeof(int32_t)); + if (version > FASTTEXT_VERSION) + { + return false; + } + return true; } - if (!input_ || !output_) + + void FastText::signModel(std::ostream &out) { - throw std::runtime_error("Model never trained"); + const int32_t magic = FASTTEXT_FILEFORMAT_MAGIC_INT32; + const int32_t version = FASTTEXT_VERSION; + out.write((char *)&(magic), sizeof(int32_t)); + out.write((char *)&(version), sizeof(int32_t)); } - signModel(ofs); - args_->save(ofs); - dict_->save(ofs); - ofs.write((char *)&(quant_), sizeof(bool)); - input_->save(ofs); + void FastText::saveModel(const std::string &filename) + { + std::ofstream ofs(filename, std::ofstream::binary); + if (!ofs.is_open()) + { + throw std::invalid_argument(filename + " cannot be opened for saving!"); + } + if (!input_ || !output_) + { + throw std::runtime_error("Model never trained"); + } + signModel(ofs); + args_->save(ofs); + dict_->save(ofs); + + ofs.write((char *)&(quant_), sizeof(bool)); + input_->save(ofs); - ofs.write((char *)&(args_->qout), sizeof(bool)); - output_->save(ofs); + ofs.write((char *)&(args_->qout), sizeof(bool)); + output_->save(ofs); - ofs.close(); -} + ofs.close(); + } -void FastText::loadModel(const std::string &filename) -{ - std::ifstream ifs(filename, std::ifstream::binary); - if (!ifs.is_open()) + void FastText::loadModel(const std::string &filename) { - throw std::invalid_argument(filename + " cannot be opened for loading!"); + std::ifstream ifs(filename, std::ifstream::binary); + if (!ifs.is_open()) + { + throw std::invalid_argument(filename + " cannot be opened for loading!"); + } + if (!checkModel(ifs)) + { + throw std::invalid_argument(filename + " has wrong file format!"); + } + loadModel(ifs); + ifs.close(); } - if (!checkModel(ifs)) + + std::vector FastText::getTargetCounts() const { - throw std::invalid_argument(filename + " has wrong file format!"); + if (args_->model == model_name::sup) + { + return dict_->getCounts(entry_type::label); + } + else + { + return dict_->getCounts(entry_type::word); + } } - loadModel(ifs); - ifs.close(); -} -std::vector FastText::getTargetCounts() const -{ - if (args_->model == model_name::sup) + void FastText::buildModel() { - return dict_->getCounts(entry_type::label); + auto loss = createLoss(output_); + bool normalizeGradient = (args_->model == model_name::sup); + model_ = std::make_shared(input_, output_, loss, normalizeGradient); } - else + + void FastText::loadModel(std::istream &in) { - return dict_->getCounts(entry_type::word); - } -} + args_ = std::make_shared(); + input_ = std::make_shared(); + output_ = std::make_shared(); + args_->load(in); + if (version == 11 && args_->model == model_name::sup) + { + // backward compatibility: old supervised models do not use char ngrams. + args_->maxn = 0; + } + dict_ = std::make_shared(args_, in); -void FastText::buildModel() -{ - auto loss = createLoss(output_); - bool normalizeGradient = (args_->model == model_name::sup); - model_ = std::make_shared(input_, output_, loss, normalizeGradient); -} + bool quant_input; + in.read((char *)&quant_input, sizeof(bool)); + if (quant_input) + { + quant_ = true; + input_ = std::make_shared(); + } + input_->load(in); -void FastText::loadModel(std::istream &in) -{ - args_ = std::make_shared(); - input_ = std::make_shared(); - output_ = std::make_shared(); - args_->load(in); - if (version == 11 && args_->model == model_name::sup) - { - // backward compatibility: old supervised models do not use char ngrams. - args_->maxn = 0; - } - dict_ = std::make_shared(args_, in); + if (!quant_input && dict_->isPruned()) + { + throw std::invalid_argument("Invalid model file.\n" + "Please download the updated model from www.fasttext.cc.\n" + "See issue #332 on Github for more information.\n"); + } - bool quant_input; - in.read((char *)&quant_input, sizeof(bool)); - if (quant_input) - { - quant_ = true; - input_ = std::make_shared(); - } - input_->load(in); + in.read((char *)&args_->qout, sizeof(bool)); + if (quant_ && args_->qout) + { + output_ = std::make_shared(); + } + output_->load(in); - if (!quant_input && dict_->isPruned()) - { - throw std::invalid_argument("Invalid model file.\n" - "Please download the updated model from www.fasttext.cc.\n" - "See issue #332 on Github for more information.\n"); + buildModel(); } - in.read((char *)&args_->qout, sizeof(bool)); - if (quant_ && args_->qout) + std::tuple FastText::progressInfo(real progress) { - output_ = std::make_shared(); - } - output_->load(in); + double t = utils::getDuration(start_, std::chrono::steady_clock::now()); + double lr = args_->lr * (1.0 - progress); + double wst = 0; - buildModel(); -} + int64_t eta = 2592000; // Default to one month in seconds (720 * 3600) -std::tuple FastText::progressInfo(real progress) -{ - double t = utils::getDuration(start_, std::chrono::steady_clock::now()); - double lr = args_->lr * (1.0 - progress); - double wst = 0; + if (progress > 0 && t >= 0) + { + eta = t * (1 - progress) / progress; + wst = double(tokenCount_) / t / args_->thread; + } - int64_t eta = 2592000; // Default to one month in seconds (720 * 3600) + return std::tuple(wst, lr, eta); + } - if (progress > 0 && t >= 0) + void FastText::printInfo(real progress, real loss, std::ostream &log_stream) { - eta = t * (1 - progress) / progress; - wst = double(tokenCount_) / t / args_->thread; - } + double wst; + double lr; + int64_t eta; + std::tie(wst, lr, eta) = progressInfo(progress); - return std::tuple(wst, lr, eta); -} + log_stream << std::fixed; + log_stream << "Progress: "; + log_stream << std::setprecision(1) << std::setw(5) << (progress * 100) << "%"; + log_stream << " words/sec/thread: " << std::setw(7) << int64_t(wst); + log_stream << " lr: " << std::setw(9) << std::setprecision(6) << lr; + log_stream << " avg.loss: " << std::setw(9) << std::setprecision(6) << loss; + log_stream << " ETA: " << utils::ClockPrint(eta); + log_stream << std::flush; + } -void FastText::printInfo(real progress, real loss, std::ostream &log_stream) -{ - double wst; - double lr; - int64_t eta; - std::tie(wst, lr, eta) = progressInfo(progress); - - log_stream << std::fixed; - log_stream << "Progress: "; - log_stream << std::setprecision(1) << std::setw(5) << (progress * 100) << "%"; - log_stream << " words/sec/thread: " << std::setw(7) << int64_t(wst); - log_stream << " lr: " << std::setw(9) << std::setprecision(6) << lr; - log_stream << " avg.loss: " << std::setw(9) << std::setprecision(6) << loss; - log_stream << " ETA: " << utils::ClockPrint(eta); - log_stream << std::flush; -} - -std::vector FastText::selectEmbeddings(int32_t cutoff) const -{ - std::shared_ptr input = std::dynamic_pointer_cast(input_); - Vector norms(input->size(0)); - input->l2NormRow(norms); - std::vector idx(input->size(0), 0); - std::iota(idx.begin(), idx.end(), 0); - auto eosid = dict_->getId(Dictionary::EOS); - std::sort(idx.begin(), idx.end(), [&norms, eosid](size_t i1, size_t i2) { + std::vector FastText::selectEmbeddings(int32_t cutoff) const + { + std::shared_ptr input = std::dynamic_pointer_cast(input_); + Vector norms(input->size(0)); + input->l2NormRow(norms); + std::vector idx(input->size(0), 0); + std::iota(idx.begin(), idx.end(), 0); + auto eosid = dict_->getId(Dictionary::EOS); + std::sort(idx.begin(), idx.end(), [&norms, eosid](size_t i1, size_t i2) + { if (i1 == eosid && i2 == eosid) { // satisfy strict weak ordering return false; } - return eosid == i1 || (eosid != i2 && norms[i1] > norms[i2]); - }); - idx.erase(idx.begin() + cutoff, idx.end()); - return idx; -} - -void FastText::quantize(const Args &qargs, const TrainCallback &callback) -{ - if (args_->model != model_name::sup) - { - throw std::invalid_argument("For now we only support quantization of supervised models"); + return eosid == i1 || (eosid != i2 && norms[i1] > norms[i2]); }); + idx.erase(idx.begin() + cutoff, idx.end()); + return idx; } - args_->input = qargs.input; - args_->qout = qargs.qout; - args_->output = qargs.output; - std::shared_ptr input = std::dynamic_pointer_cast(input_); - std::shared_ptr output = std::dynamic_pointer_cast(output_); - bool normalizeGradient = (args_->model == model_name::sup); - if (qargs.cutoff > 0 && qargs.cutoff < input->size(0)) + void FastText::quantize(const Args &qargs, const TrainCallback &callback) { - auto idx = selectEmbeddings(qargs.cutoff); - dict_->prune(idx); - std::shared_ptr ninput = std::make_shared(idx.size(), args_->dim); - for (auto i = 0; i < idx.size(); i++) + if (args_->model != model_name::sup) + { + throw std::invalid_argument("For now we only support quantization of supervised models"); + } + args_->input = qargs.input; + args_->qout = qargs.qout; + args_->output = qargs.output; + std::shared_ptr input = std::dynamic_pointer_cast(input_); + std::shared_ptr output = std::dynamic_pointer_cast(output_); + bool normalizeGradient = (args_->model == model_name::sup); + + if (qargs.cutoff > 0 && qargs.cutoff < input->size(0)) { - for (auto j = 0; j < args_->dim; j++) + auto idx = selectEmbeddings(qargs.cutoff); + dict_->prune(idx); + std::shared_ptr ninput = std::make_shared(idx.size(), args_->dim); + for (auto i = 0; i < idx.size(); i++) { - ninput->at(i, j) = input->at(idx[i], j); + for (auto j = 0; j < args_->dim; j++) + { + ninput->at(i, j) = input->at(idx[i], j); + } + } + input = ninput; + if (qargs.retrain) + { + args_->epoch = qargs.epoch; + args_->lr = qargs.lr; + args_->thread = qargs.thread; + args_->verbose = qargs.verbose; + auto loss = createLoss(output_); + model_ = std::make_shared(input, output, loss, normalizeGradient); + startThreads(callback); } } - input = ninput; - if (qargs.retrain) + input_ = std::make_shared(std::move(*(input.get())), qargs.dsub, qargs.qnorm); + + if (args_->qout) { - args_->epoch = qargs.epoch; - args_->lr = qargs.lr; - args_->thread = qargs.thread; - args_->verbose = qargs.verbose; - auto loss = createLoss(output_); - model_ = std::make_shared(input, output, loss, normalizeGradient); - startThreads(callback); + output_ = std::make_shared(std::move(*(output.get())), 2, qargs.qnorm); } + quant_ = true; + auto loss = createLoss(output_); + model_ = std::make_shared(input_, output_, loss, normalizeGradient); } - input_ = std::make_shared(std::move(*(input.get())), qargs.dsub, qargs.qnorm); - - if (args_->qout) - { - output_ = std::make_shared(std::move(*(output.get())), 2, qargs.qnorm); - } - quant_ = true; - auto loss = createLoss(output_); - model_ = std::make_shared(input_, output_, loss, normalizeGradient); -} -void FastText::supervised(Model::State &state, real lr, const std::vector &line, - const std::vector &labels) -{ - if (labels.size() == 0 || line.size() == 0) - { - return; - } - if (args_->loss == loss_name::ova) - { - model_->update(line, labels, Model::kAllLabelsAsTarget, lr, state); - } - else + void FastText::supervised(Model::State &state, real lr, const std::vector &line, + const std::vector &labels) { - std::uniform_int_distribution<> uniform(0, labels.size() - 1); - int32_t i = uniform(state.rng); - model_->update(line, labels, i, lr, state); + if (labels.size() == 0 || line.size() == 0) + { + return; + } + if (args_->loss == loss_name::ova) + { + model_->update(line, labels, Model::kAllLabelsAsTarget, lr, state); + } + else + { + std::uniform_int_distribution<> uniform(0, labels.size() - 1); + int32_t i = uniform(state.rng); + model_->update(line, labels, i, lr, state); + } } -} -void FastText::cbow(Model::State &state, real lr, const std::vector &line) -{ - std::vector bow; - std::uniform_int_distribution<> uniform(1, args_->ws); - for (int32_t w = 0; w < line.size(); w++) + void FastText::cbow(Model::State &state, real lr, const std::vector &line) { - int32_t boundary = uniform(state.rng); - bow.clear(); - for (int32_t c = -boundary; c <= boundary; c++) + std::vector bow; + std::uniform_int_distribution<> uniform(1, args_->ws); + for (int32_t w = 0; w < line.size(); w++) { - if (c != 0 && w + c >= 0 && w + c < line.size()) + int32_t boundary = uniform(state.rng); + bow.clear(); + for (int32_t c = -boundary; c <= boundary; c++) { - const std::vector &ngrams = dict_->getSubwords(line[w + c]); - bow.insert(bow.end(), ngrams.cbegin(), ngrams.cend()); + if (c != 0 && w + c >= 0 && w + c < line.size()) + { + const std::vector &ngrams = dict_->getSubwords(line[w + c]); + bow.insert(bow.end(), ngrams.cbegin(), ngrams.cend()); + } } + model_->update(bow, line, w, lr, state); } - model_->update(bow, line, w, lr, state); } -} -void FastText::skipgram(Model::State &state, real lr, const std::vector &line) -{ - std::uniform_int_distribution<> uniform(1, args_->ws); - for (int32_t w = 0; w < line.size(); w++) + void FastText::skipgram(Model::State &state, real lr, const std::vector &line) { - int32_t boundary = uniform(state.rng); - const std::vector &ngrams = dict_->getSubwords(line[w]); - for (int32_t c = -boundary; c <= boundary; c++) + std::uniform_int_distribution<> uniform(1, args_->ws); + for (int32_t w = 0; w < line.size(); w++) { - if (c != 0 && w + c >= 0 && w + c < line.size()) + int32_t boundary = uniform(state.rng); + const std::vector &ngrams = dict_->getSubwords(line[w]); + for (int32_t c = -boundary; c <= boundary; c++) { - model_->update(ngrams, line, w + c, lr, state); + if (c != 0 && w + c >= 0 && w + c < line.size()) + { + model_->update(ngrams, line, w + c, lr, state); + } } } } -} - -std::tuple FastText::test(std::istream &in, int32_t k, real threshold) -{ - Meter meter(false); - test(in, k, threshold, meter); - return std::tuple(meter.nexamples(), meter.precision(), meter.recall()); -} + std::tuple FastText::test(std::istream &in, int32_t k, real threshold) + { + Meter meter(false); + test(in, k, threshold, meter); -void FastText::test(std::istream &in, int32_t k, real threshold, Meter &meter) const -{ - std::vector line; - std::vector labels; - Predictions predictions; - Model::State state(args_->dim, dict_->nlabels(), 0); - in.clear(); - in.seekg(0, std::ios_base::beg); + return std::tuple(meter.nexamples(), meter.precision(), meter.recall()); + } - while (in.peek() != EOF) + void FastText::test(std::istream &in, int32_t k, real threshold, Meter &meter) const { - line.clear(); - labels.clear(); - dict_->getLine(in, line, labels); + std::vector line; + std::vector labels; + Predictions predictions; + Model::State state(args_->dim, dict_->nlabels(), 0); + in.clear(); + in.seekg(0, std::ios_base::beg); - if (!labels.empty() && !line.empty()) + while (in.peek() != EOF) { - predictions.clear(); - predict(k, line, predictions, threshold); - meter.log(labels, predictions); + line.clear(); + labels.clear(); + dict_->getLine(in, line, labels); + + if (!labels.empty() && !line.empty()) + { + predictions.clear(); + predict(k, line, predictions, threshold); + meter.log(labels, predictions); + } } } -} -void FastText::predict(int32_t k, const std::vector &words, Predictions &predictions, real threshold) const -{ - if (words.empty()) + void FastText::predict(int32_t k, const std::vector &words, Predictions &predictions, real threshold) const { - return; + if (words.empty()) + { + return; + } + Model::State state(args_->dim, dict_->nlabels(), 0); + if (args_->model != model_name::sup) + { + throw std::invalid_argument("Model needs to be supervised for prediction!"); + } + model_->predict(words, k, threshold, predictions, state); } - Model::State state(args_->dim, dict_->nlabels(), 0); - if (args_->model != model_name::sup) + + FullPrediction FastText::predictFull(int32_t k, std::string_view in, real threshold) const { - throw std::invalid_argument("Model needs to be supervised for prediction!"); + std::vector words, labels; + dict_->getStringNoNewline(in, words, labels); + Predictions linePredictions; + linePredictions.reserve(k); + this->predict(k, words, linePredictions, threshold); + + return FullPrediction(std::move(linePredictions), dict_); } - model_->predict(words, k, threshold, predictions, state); -} -FullPrediction FastText::predictFull(int32_t k, std::string_view in, real threshold) const -{ - std::vector words, labels; - dict_->getStringNoNewline(in, words, labels); - Predictions linePredictions; - linePredictions.reserve(k); - this->predict(k, words, linePredictions, threshold); + bool FastText::predictLine(std::istream &in, std::vector> &predictions, int32_t k, + real threshold) const + { + predictions.clear(); + if (in.peek() == EOF) + { + return false; + } - return FullPrediction(std::move(linePredictions), dict_); -} + std::vector words, labels; + dict_->getLine(in, words, labels); + Predictions linePredictions; + predict(k, words, linePredictions, threshold); + for (const auto &p : linePredictions) + { + predictions.push_back(std::make_pair(std::exp(p.first), dict_->getLabel(p.second))); + } -bool FastText::predictLine(std::istream &in, std::vector> &predictions, int32_t k, - real threshold) const -{ - predictions.clear(); - if (in.peek() == EOF) - { - return false; + return true; } - std::vector words, labels; - dict_->getLine(in, words, labels); - Predictions linePredictions; - predict(k, words, linePredictions, threshold); - for (const auto &p : linePredictions) + Vector FastText::getSentenceVector(std::istream &in) const { - predictions.push_back(std::make_pair(std::exp(p.first), dict_->getLabel(p.second))); - } - - return true; -} + Vector svec(args_->dim); + svec.zero(); -void FastText::getSentenceVector(std::istream &in, fasttext::Vector &svec) -{ - svec.zero(); - if (args_->model == model_name::sup) - { - std::vector line, labels; - dict_->getLine(in, line, labels); - for (int32_t i = 0; i < line.size(); i++) + if (args_->model == model_name::sup) { - addInputVector(svec, line[i]); + std::vector line, labels; + dict_->getLine(in, line, labels); + for (int32_t i = 0; i < line.size(); i++) + { + addInputVector(svec, line[i]); + } + if (!line.empty()) + { + svec.mul(1.0 / line.size()); + } } - if (!line.empty()) + else { - svec.mul(1.0 / line.size()); + Vector vec(args_->dim); + std::string sentence; + std::getline(in, sentence); + std::istringstream iss(sentence); + std::string word; + int32_t count = 0; + while (iss >> word) + { + getWordVector(vec, word); + real norm = vec.norm(); + if (norm > 0) + { + vec.mul(1.0 / norm); + svec.addVector(vec); + count++; + } + } + if (count > 0) + { + svec.mul(1.0 / count); + } } + + return std::move(svec); } - else + + void FastText::getSentenceVector(std::istream &in, fasttext::Vector &svec) { - Vector vec(args_->dim); - std::string sentence; - std::getline(in, sentence); - std::istringstream iss(sentence); - std::string word; - int32_t count = 0; - while (iss >> word) + svec.zero(); + if (args_->model == model_name::sup) { - getWordVector(vec, word); - real norm = vec.norm(); - if (norm > 0) + std::vector line, labels; + dict_->getLine(in, line, labels); + for (int32_t i = 0; i < line.size(); i++) { - vec.mul(1.0 / norm); - svec.addVector(vec); - count++; + addInputVector(svec, line[i]); + } + if (!line.empty()) + { + svec.mul(1.0 / line.size()); } } - if (count > 0) + else { - svec.mul(1.0 / count); + Vector vec(args_->dim); + std::string sentence; + std::getline(in, sentence); + std::istringstream iss(sentence); + std::string word; + int32_t count = 0; + while (iss >> word) + { + getWordVector(vec, word); + real norm = vec.norm(); + if (norm > 0) + { + vec.mul(1.0 / norm); + svec.addVector(vec); + count++; + } + } + if (count > 0) + { + svec.mul(1.0 / count); + } } } -} -std::vector> FastText::getNgramVectors(const std::string &word) const -{ - std::vector> result; - std::vector ngrams; - std::vector substrings; - dict_->getSubwords(word, ngrams, substrings); - assert(ngrams.size() <= substrings.size()); - for (int32_t i = 0; i < ngrams.size(); i++) + std::vector> FastText::getNgramVectors(const std::string &word) const { - Vector vec(args_->dim); - if (ngrams[i] >= 0) + std::vector> result; + std::vector ngrams; + std::vector substrings; + dict_->getSubwords(word, ngrams, substrings); + assert(ngrams.size() <= substrings.size()); + for (int32_t i = 0; i < ngrams.size(); i++) { - vec.addRow(*input_, ngrams[i]); + Vector vec(args_->dim); + if (ngrams[i] >= 0) + { + vec.addRow(*input_, ngrams[i]); + } + result.emplace_back(substrings[i], std::move(vec)); } - result.emplace_back(substrings[i], std::move(vec)); + return result; } - return result; -} -void FastText::precomputeWordVectors(DenseMatrix &wordVectors) -{ - Vector vec(args_->dim); - wordVectors.zero(); - for (int32_t i = 0; i < dict_->nwords(); i++) + void FastText::precomputeWordVectors(DenseMatrix &wordVectors) { - std::string word = dict_->getWord(i); - getWordVector(vec, word); - real norm = vec.norm(); - if (norm > 0) + Vector vec(args_->dim); + wordVectors.zero(); + for (int32_t i = 0; i < dict_->nwords(); i++) { - wordVectors.addVectorToRow(vec, i, 1.0 / norm); + std::string word = dict_->getWord(i); + getWordVector(vec, word); + real norm = vec.norm(); + if (norm > 0) + { + wordVectors.addVectorToRow(vec, i, 1.0 / norm); + } } } -} -void FastText::lazyComputeWordVectors() -{ - if (!wordVectors_) + void FastText::lazyComputeWordVectors() { - wordVectors_ = std::unique_ptr(new DenseMatrix(dict_->nwords(), args_->dim)); - precomputeWordVectors(*wordVectors_); + if (!wordVectors_) + { + wordVectors_ = std::unique_ptr(new DenseMatrix(dict_->nwords(), args_->dim)); + precomputeWordVectors(*wordVectors_); + } } -} - -std::vector> FastText::getNN(const std::string &word, int32_t k) -{ - Vector query(args_->dim); - - getWordVector(query, word); - lazyComputeWordVectors(); - assert(wordVectors_); - return getNN(*wordVectors_, query, k, {word}); -} + std::vector> FastText::getNN(const std::string &word, int32_t k) + { + Vector query(args_->dim); -std::vector> FastText::getNN(const DenseMatrix &wordVectors, const Vector &query, - int32_t k, const std::set &banSet) -{ - std::vector> heap; + getWordVector(query, word); - real queryNorm = query.norm(); - if (std::abs(queryNorm) < 1e-8) - { - queryNorm = 1; + lazyComputeWordVectors(); + assert(wordVectors_); + return getNN(*wordVectors_, query, k, {word}); } - for (int32_t i = 0; i < dict_->nwords(); i++) + std::vector> FastText::getNN(const DenseMatrix &wordVectors, const Vector &query, + int32_t k, const std::set &banSet) { - std::string word = dict_->getWord(i); - if (banSet.find(word) == banSet.end()) + std::vector> heap; + + real queryNorm = query.norm(); + if (std::abs(queryNorm) < 1e-8) { - real dp = wordVectors.dotRow(query, i); - real similarity = dp / queryNorm; + queryNorm = 1; + } - if (heap.size() == k && similarity < heap.front().first) - { - continue; - } - heap.push_back(std::make_pair(similarity, word)); - std::push_heap(heap.begin(), heap.end(), comparePairs); - if (heap.size() > k) + for (int32_t i = 0; i < dict_->nwords(); i++) + { + std::string word = dict_->getWord(i); + if (banSet.find(word) == banSet.end()) { - std::pop_heap(heap.begin(), heap.end(), comparePairs); - heap.pop_back(); + real dp = wordVectors.dotRow(query, i); + real similarity = dp / queryNorm; + + if (heap.size() == k && similarity < heap.front().first) + { + continue; + } + heap.push_back(std::make_pair(similarity, word)); + std::push_heap(heap.begin(), heap.end(), comparePairs); + if (heap.size() > k) + { + std::pop_heap(heap.begin(), heap.end(), comparePairs); + heap.pop_back(); + } } } + std::sort_heap(heap.begin(), heap.end(), comparePairs); + + return heap; } - std::sort_heap(heap.begin(), heap.end(), comparePairs); - return heap; -} + std::vector> FastText::getAnalogies(int32_t k, const std::string &wordA, + const std::string &wordB, const std::string &wordC) + { + Vector query(args_->dim); + query.zero(); -std::vector> FastText::getAnalogies(int32_t k, const std::string &wordA, - const std::string &wordB, const std::string &wordC) -{ - Vector query(args_->dim); - query.zero(); - - Vector buffer(args_->dim); - getWordVector(buffer, wordA); - query.addVector(buffer, 1.0 / (buffer.norm() + 1e-8)); - getWordVector(buffer, wordB); - query.addVector(buffer, -1.0 / (buffer.norm() + 1e-8)); - getWordVector(buffer, wordC); - query.addVector(buffer, 1.0 / (buffer.norm() + 1e-8)); - - lazyComputeWordVectors(); - assert(wordVectors_); - return getNN(*wordVectors_, query, k, {wordA, wordB, wordC}); -} - -bool FastText::keepTraining(const int64_t ntokens) const -{ - return tokenCount_ < args_->epoch * ntokens && !trainException_; -} + Vector buffer(args_->dim); + getWordVector(buffer, wordA); + query.addVector(buffer, 1.0 / (buffer.norm() + 1e-8)); + getWordVector(buffer, wordB); + query.addVector(buffer, -1.0 / (buffer.norm() + 1e-8)); + getWordVector(buffer, wordC); + query.addVector(buffer, 1.0 / (buffer.norm() + 1e-8)); -void FastText::trainThread(int32_t threadId, const TrainCallback &callback) -{ - std::ifstream ifs(args_->input); - utils::seek(ifs, threadId * utils::size(ifs) / args_->thread); + lazyComputeWordVectors(); + assert(wordVectors_); + return getNN(*wordVectors_, query, k, {wordA, wordB, wordC}); + } - Model::State state(args_->dim, output_->size(0), threadId + args_->seed); + bool FastText::keepTraining(const int64_t ntokens) const + { + return tokenCount_ < args_->epoch * ntokens && !trainException_; + } - const int64_t ntokens = dict_->ntokens(); - int64_t localTokenCount = 0; - std::vector line, labels; - uint64_t callbackCounter = 0; - try + void FastText::trainThread(int32_t threadId, const TrainCallback &callback) { - while (keepTraining(ntokens)) + std::ifstream ifs(args_->input); + utils::seek(ifs, threadId * utils::size(ifs) / args_->thread); + + Model::State state(args_->dim, output_->size(0), threadId + args_->seed); + + const int64_t ntokens = dict_->ntokens(); + int64_t localTokenCount = 0; + std::vector line, labels; + uint64_t callbackCounter = 0; + try { - real progress = real(tokenCount_) / (args_->epoch * ntokens); - if (callback && ((callbackCounter++ % 64) == 0)) - { - double wst; - double lr; - int64_t eta; - std::tie(wst, lr, eta) = progressInfo(progress); - callback(progress, loss_, wst, lr, eta); - } - real lr = args_->lr * (1.0 - progress); - if (args_->model == model_name::sup) - { - localTokenCount += dict_->getLine(ifs, line, labels); - supervised(state, lr, line, labels); - } - else if (args_->model == model_name::cbow) - { - localTokenCount += dict_->getLine(ifs, line, state.rng); - cbow(state, lr, line); - } - else if (args_->model == model_name::sg) - { - localTokenCount += dict_->getLine(ifs, line, state.rng); - skipgram(state, lr, line); - } - if (localTokenCount > args_->lrUpdateRate) + while (keepTraining(ntokens)) { - tokenCount_ += localTokenCount; - localTokenCount = 0; - if (threadId == 0 && args_->verbose > 1) + real progress = real(tokenCount_) / (args_->epoch * ntokens); + if (callback && ((callbackCounter++ % 64) == 0)) + { + double wst; + double lr; + int64_t eta; + std::tie(wst, lr, eta) = progressInfo(progress); + callback(progress, loss_, wst, lr, eta); + } + real lr = args_->lr * (1.0 - progress); + if (args_->model == model_name::sup) { - loss_ = state.getLoss(); + localTokenCount += dict_->getLine(ifs, line, labels); + supervised(state, lr, line, labels); + } + else if (args_->model == model_name::cbow) + { + localTokenCount += dict_->getLine(ifs, line, state.rng); + cbow(state, lr, line); + } + else if (args_->model == model_name::sg) + { + localTokenCount += dict_->getLine(ifs, line, state.rng); + skipgram(state, lr, line); + } + if (localTokenCount > args_->lrUpdateRate) + { + tokenCount_ += localTokenCount; + localTokenCount = 0; + if (threadId == 0 && args_->verbose > 1) + { + loss_ = state.getLoss(); + } } } } - } - catch (DenseMatrix::EncounteredNaNError &) - { - trainException_ = std::current_exception(); - } - if (threadId == 0) - loss_ = state.getLoss(); - ifs.close(); -} - -std::shared_ptr FastText::getInputMatrixFromFile(const std::string &filename) const -{ - std::ifstream in(filename); - std::vector words; - std::shared_ptr mat; // temp. matrix for pretrained vectors - int64_t n, dim; - if (!in.is_open()) - { - throw std::invalid_argument(filename + " cannot be opened for loading!"); - } - in >> n >> dim; - if (dim != args_->dim) - { - throw std::invalid_argument("Dimension of pretrained vectors (" + std::to_string(dim) + - ") does not match dimension (" + std::to_string(args_->dim) + ")!"); - } - mat = std::make_shared(n, dim); - for (size_t i = 0; i < n; i++) - { - std::string word; - in >> word; - words.push_back(word); - dict_->add(word); - for (size_t j = 0; j < dim; j++) + catch (DenseMatrix::EncounteredNaNError &) { - in >> mat->at(i, j); + trainException_ = std::current_exception(); } + if (threadId == 0) + loss_ = state.getLoss(); + ifs.close(); } - in.close(); - - dict_->threshold(1, 0); - dict_->init(); - std::shared_ptr input = std::make_shared(dict_->nwords() + args_->bucket, args_->dim); - input->uniform(1.0 / args_->dim, args_->thread, args_->seed); - for (size_t i = 0; i < n; i++) + std::shared_ptr FastText::getInputMatrixFromFile(const std::string &filename) const { - int32_t idx = dict_->getId(words[i]); - if (idx < 0 || idx >= dict_->nwords()) + std::ifstream in(filename); + std::vector words; + std::shared_ptr mat; // temp. matrix for pretrained vectors + int64_t n, dim; + if (!in.is_open()) { - continue; + throw std::invalid_argument(filename + " cannot be opened for loading!"); } - for (size_t j = 0; j < dim; j++) + in >> n >> dim; + if (dim != args_->dim) { - input->at(idx, j) = mat->at(i, j); + throw std::invalid_argument("Dimension of pretrained vectors (" + std::to_string(dim) + + ") does not match dimension (" + std::to_string(args_->dim) + ")!"); } - } - return input; -} - -std::shared_ptr FastText::createRandomMatrix() const -{ - std::shared_ptr input = std::make_shared(dict_->nwords() + args_->bucket, args_->dim); - input->uniform(1.0 / args_->dim, args_->thread, args_->seed); - - return input; -} - -std::shared_ptr FastText::createTrainOutputMatrix() const -{ - int64_t m = (args_->model == model_name::sup) ? dict_->nlabels() : dict_->nwords(); - std::shared_ptr output = std::make_shared(m, args_->dim); - output->zero(); + mat = std::make_shared(n, dim); + for (size_t i = 0; i < n; i++) + { + std::string word; + in >> word; + words.push_back(word); + dict_->add(word); + for (size_t j = 0; j < dim; j++) + { + in >> mat->at(i, j); + } + } + in.close(); - return output; -} + dict_->threshold(1, 0); + dict_->init(); + std::shared_ptr input = std::make_shared(dict_->nwords() + args_->bucket, args_->dim); + input->uniform(1.0 / args_->dim, args_->thread, args_->seed); -void FastText::train(const Args &args, const TrainCallback &callback) -{ - args_ = std::make_shared(args); - dict_ = std::make_shared(args_); - if (args_->input == "-") - { - // manage expectations - throw std::invalid_argument("Cannot use stdin for training!"); - } - std::ifstream ifs(args_->input); - if (!ifs.is_open()) - { - throw std::invalid_argument(args_->input + " cannot be opened for training!"); + for (size_t i = 0; i < n; i++) + { + int32_t idx = dict_->getId(words[i]); + if (idx < 0 || idx >= dict_->nwords()) + { + continue; + } + for (size_t j = 0; j < dim; j++) + { + input->at(idx, j) = mat->at(i, j); + } + } + return input; } - dict_->readFromFile(ifs); - ifs.close(); - if (!args_->pretrainedVectors.empty()) + std::shared_ptr FastText::createRandomMatrix() const { - input_ = getInputMatrixFromFile(args_->pretrainedVectors); - } - else - { - input_ = createRandomMatrix(); - } - output_ = createTrainOutputMatrix(); - quant_ = false; - auto loss = createLoss(output_); - bool normalizeGradient = (args_->model == model_name::sup); - model_ = std::make_shared(input_, output_, loss, normalizeGradient); - startThreads(callback); -} + std::shared_ptr input = std::make_shared(dict_->nwords() + args_->bucket, args_->dim); + input->uniform(1.0 / args_->dim, args_->thread, args_->seed); -void FastText::abort() -{ - try - { - throw AbortError(); + return input; } - catch (AbortError &) + + std::shared_ptr FastText::createTrainOutputMatrix() const { - trainException_ = std::current_exception(); + int64_t m = (args_->model == model_name::sup) ? dict_->nlabels() : dict_->nwords(); + std::shared_ptr output = std::make_shared(m, args_->dim); + output->zero(); + + return output; } -} -void FastText::startThreads(const TrainCallback &callback) -{ - start_ = std::chrono::steady_clock::now(); - tokenCount_ = 0; - loss_ = -1; - trainException_ = nullptr; - std::vector threads; - if (args_->thread > 1) + void FastText::train(const Args &args, const TrainCallback &callback) { - for (int32_t i = 0; i < args_->thread; i++) + args_ = std::make_shared(args); + dict_ = std::make_shared(args_); + if (args_->input == "-") + { + // manage expectations + throw std::invalid_argument("Cannot use stdin for training!"); + } + std::ifstream ifs(args_->input); + if (!ifs.is_open()) + { + throw std::invalid_argument(args_->input + " cannot be opened for training!"); + } + dict_->readFromFile(ifs); + ifs.close(); + + if (!args_->pretrainedVectors.empty()) + { + input_ = getInputMatrixFromFile(args_->pretrainedVectors); + } + else { - threads.push_back(std::thread([=]() { trainThread(i, callback); })); + input_ = createRandomMatrix(); } + output_ = createTrainOutputMatrix(); + quant_ = false; + auto loss = createLoss(output_); + bool normalizeGradient = (args_->model == model_name::sup); + model_ = std::make_shared(input_, output_, loss, normalizeGradient); + startThreads(callback); } - else + + void FastText::abort() { - // webassembly can't instantiate `std::thread` - trainThread(0, callback); + try + { + throw AbortError(); + } + catch (AbortError &) + { + trainException_ = std::current_exception(); + } } - const int64_t ntokens = dict_->ntokens(); - // Same condition as trainThread - while (keepTraining(ntokens)) + + void FastText::startThreads(const TrainCallback &callback) { - std::this_thread::sleep_for(std::chrono::milliseconds(100)); - if (loss_ >= 0 && args_->verbose > 1) + start_ = std::chrono::steady_clock::now(); + tokenCount_ = 0; + loss_ = -1; + trainException_ = nullptr; + std::vector threads; + if (args_->thread > 1) + { + for (int32_t i = 0; i < args_->thread; i++) + { + threads.push_back(std::thread([=]() + { trainThread(i, callback); })); + } + } + else + { + // webassembly can't instantiate `std::thread` + trainThread(0, callback); + } + const int64_t ntokens = dict_->ntokens(); + // Same condition as trainThread + while (keepTraining(ntokens)) + { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + if (loss_ >= 0 && args_->verbose > 1) + { + real progress = real(tokenCount_) / (args_->epoch * ntokens); + std::cerr << "\r"; + printInfo(progress, loss_, std::cerr); + } + } + for (int32_t i = 0; i < threads.size(); i++) + { + threads[i].join(); + } + if (trainException_) + { + std::exception_ptr exception = trainException_; + trainException_ = nullptr; + std::rethrow_exception(exception); + } + if (args_->verbose > 0) { - real progress = real(tokenCount_) / (args_->epoch * ntokens); std::cerr << "\r"; - printInfo(progress, loss_, std::cerr); + printInfo(1.0, loss_, std::cerr); + std::cerr << std::endl; } } - for (int32_t i = 0; i < threads.size(); i++) + + int FastText::getDimension() const { - threads[i].join(); + return args_->dim; } - if (trainException_) + + bool FastText::isQuant() const { - std::exception_ptr exception = trainException_; - trainException_ = nullptr; - std::rethrow_exception(exception); + return quant_; } - if (args_->verbose > 0) + + bool comparePairs(const std::pair &l, const std::pair &r) { - std::cerr << "\r"; - printInfo(1.0, loss_, std::cerr); - std::cerr << std::endl; + return l.first > r.first; } -} - -int FastText::getDimension() const -{ - return args_->dim; -} - -bool FastText::isQuant() const -{ - return quant_; -} - -bool comparePairs(const std::pair &l, const std::pair &r) -{ - return l.first > r.first; -} } // namespace fasttext diff --git a/predictions.cpp b/predictions.cpp index 0930060..cf95b3b 100644 --- a/predictions.cpp +++ b/predictions.cpp @@ -1,4 +1,5 @@ #include +#include #include "predictions.h" @@ -23,7 +24,6 @@ BEGIN_EXTERN_C() FastText_FloatVector_t FastText_Wordvec(const FastText_Handle_t handle, FastText_String_t word) { const auto model = reinterpret_cast(handle); - int64_t dimensions = model->getDimension(); auto vec = new fasttext::Vector(std::move(model->getWordVector(std::string_view(word.data, word.size)))); return FastText_FloatVector_t{ @@ -33,23 +33,18 @@ FastText_FloatVector_t FastText_Wordvec(const FastText_Handle_t handle, FastText }; } -// FastText_FloatVector_t FastText_Sentencevec(const FastText_Handle_t handle, FastText_String_t sentence) -// { -// const auto model = reinterpret_cast(handle); - -// membuf sbuf(sentence); -// std::istream in(&sbuf); - -// auto vec = new fasttext::Vector(model->getDimension()); -// model->getSentenceVector(in, *vec); -// FREE_STRING(sentence); +FastText_FloatVector_t FastText_Sentencevec(const FastText_Handle_t handle, FastText_String_t sentence) +{ + const auto model = reinterpret_cast(handle); + std::stringstream ss(sentence.data); + auto vec = new fasttext::Vector(std::move(model->getSentenceVector(ss))); -// return FastText_FloatVector_t{ -// vec->data(), -// (void *)vec, -// (size_t)vec->size(), -// }; -// } + return FastText_FloatVector_t{ + vec->data(), + (void *)vec, + (size_t)vec->size(), + }; +} void FastText_FreeFloatVector(FastText_FloatVector_t vector) { diff --git a/predictions.h b/predictions.h index 406ba4d..97ffd1f 100644 --- a/predictions.h +++ b/predictions.h @@ -6,18 +6,18 @@ #define LABEL_PREFIX ("__label__") #define LABEL_PREFIX_SIZE (sizeof(LABEL_PREFIX) - 1) -#define FREE_STRING(str) \ - do \ - { \ - if (str.data != nullptr) \ - free(str.data); \ - str.data = nullptr; \ - str.size = 0; \ +#define FREE_STRING(str) \ + do \ + { \ + if (str.data != nullptr) \ + free(str.data); \ + str.data = nullptr; \ + str.size = 0; \ } while (0) #ifdef __cplusplus -#define BEGIN_EXTERN_C() \ - extern "C" \ +#define BEGIN_EXTERN_C() \ + extern "C" \ { #else #define BEGIN_EXTERN_C() @@ -29,7 +29,8 @@ #define END_EXTERN_C() #endif -BEGIN_EXTERN_C() typedef void *FastText_Handle_t; +BEGIN_EXTERN_C() +typedef void *FastText_Handle_t; typedef struct { @@ -58,7 +59,8 @@ typedef struct ERROR = 1, } status; - union { + union + { FastText_Handle_t handle; char *error; }; @@ -69,6 +71,7 @@ void FastText_DeleteHandle(const FastText_Handle_t handle); size_t FastText_Predict(const FastText_Handle_t handle, FastText_String_t query, uint32_t k, float threshold, FastText_PredictItem_t *const value); FastText_FloatVector_t FastText_Wordvec(const FastText_Handle_t handle, FastText_String_t word); +FastText_FloatVector_t FastText_Sentencevec(const FastText_Handle_t handle, FastText_String_t sentence); void FastText_FreeFloatVector(FastText_FloatVector_t vector);