From c4144cbf23f6b0f363585f22ad70096f7c34f9c3 Mon Sep 17 00:00:00 2001 From: Adhhitha Dias Date: Fri, 4 Mar 2022 13:19:19 -0500 Subject: [PATCH 1/2] task: implement kernel fusion over distribution --- .gitignore | 3 + CMakeLists.txt | 2 +- include/taco/codegen/module.h | 11 +- include/taco/index_notation/transformations.h | 13 + include/taco/tensor.h | 2 + src/codegen/module.cpp | 55 ++ src/index_notation/transformations.cpp | 667 ++++++++++++++++++ src/tensor.cpp | 58 ++ tools/taco.cpp | 26 +- 9 files changed, 830 insertions(+), 7 deletions(-) diff --git a/.gitignore b/.gitignore index 16389f34e..0be9e12a1 100644 --- a/.gitignore +++ b/.gitignore @@ -12,3 +12,6 @@ CMakeCache.txt doc apps/tensor_times_vector/tensor_times_vector + +.CMakeCache +.vscode diff --git a/CMakeLists.txt b/CMakeLists.txt index a6a80d9d1..74e2af67e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -11,7 +11,7 @@ project(taco ) option(CUDA "Build for NVIDIA GPU (CUDA must be preinstalled)" OFF) option(PYTHON "Build TACO for python environment" OFF) -option(OPENMP "Build with OpenMP execution support" OFF) +option(OPENMP "Build with OpenMP execution support" ON) option(COVERAGE "Build with code coverage analysis" OFF) set(TACO_FEATURE_CUDA 0) set(TACO_FEATURE_OPENMP 0) diff --git a/include/taco/codegen/module.h b/include/taco/codegen/module.h index 36eb34f1a..38c2bbcf8 100644 --- a/include/taco/codegen/module.h +++ b/include/taco/codegen/module.h @@ -17,7 +17,7 @@ class Module { public: /// Create a module for some target Module(Target target=getTargetFromEnvironment()) - : lib_handle(nullptr), moduleFromUserSource(false), target(target) { + : lib_handle(nullptr), so_lib_handle(nullptr), moduleFromUserSource(false), target(target) { setJITLibname(); setJITTmpdir(); } @@ -44,19 +44,27 @@ class Module { /// before calling. If there's no function of this name then a nullptr is /// returned. void* getFuncPtr(std::string name); + void* getFuncPtr(std::string& sofile, std::string name); /// Call a raw function in this module and return the result int callFuncPackedRaw(std::string name, void** args); + int callFuncPackedRaw(std::string name, std::string&sofile, void** args); /// Call a raw function in this module and return the result int callFuncPackedRaw(std::string name, std::vector args) { return callFuncPackedRaw(name, args.data()); } + int callFuncPackedRaw(std::string name, std::string& sofile, std::vector args) { + return callFuncPackedRaw(name, sofile, args.data()); + } /// Call a function using the taco_tensor_t interface and return the result int callFuncPacked(std::string name, void** args) { return callFuncPackedRaw("_shim_"+name, args); } + int callFuncPacked(std::string name, std::string& sofile, void** args) { + return callFuncPackedRaw("__shim__"+name, sofile, args); + } /// Call a function using the taco_tensor_t interface and return the result int callFuncPacked(std::string name, std::vector args) { @@ -72,6 +80,7 @@ class Module { std::string libname; std::string tmpdir; void* lib_handle; + void* so_lib_handle; // to pass the manually compiles so (shared object) code file std::vector funcs; // true iff the module was created from user-provided source diff --git a/include/taco/index_notation/transformations.h b/include/taco/index_notation/transformations.h index 7aa2579ad..cc511eba7 100644 --- a/include/taco/index_notation/transformations.h +++ b/include/taco/index_notation/transformations.h @@ -223,6 +223,19 @@ IndexStmt parallelizeOuterLoop(IndexStmt stmt); */ IndexStmt reorderLoopsTopologically(IndexStmt stmt); +/** + * @brief Transform topologically reordered iteration graph to + * a branched version + * + * @param stmt topologically sorted index statement + * @param assignment assignment statement of the base statement + * @param side side in which the operation is performed front/back + * @param iters number of iterations + * @return IndexStmt + */ +IndexStmt loopFusionOverFission(IndexStmt stmt, Assignment assignment, + std::string side, int iters); + /** * Performs scalar promotion so that reductions are done by accumulating into * scalar temporaries whenever possible. diff --git a/include/taco/tensor.h b/include/taco/tensor.h index b91782256..883718fb6 100644 --- a/include/taco/tensor.h +++ b/include/taco/tensor.h @@ -413,6 +413,8 @@ class TensorBase { /// Compile the tensor expression. void compile(); + void compute(std::ofstream& statfile); + void compute(std::ofstream& statfile, std::string& sofile); void compile(IndexStmt stmt, bool assembleWhileCompute=false); diff --git a/src/codegen/module.cpp b/src/codegen/module.cpp index bd0f487b1..7d165e377 100644 --- a/src/codegen/module.cpp +++ b/src/codegen/module.cpp @@ -138,6 +138,12 @@ string Module::compile() { prefix + file_ending + " " + shims_file + " " + "-o " + fullpath + " -lm"; + std::cout << "--- tmpdir: " << tmpdir << std::endl + << "--- libname: " << libname << std::endl + << "--- prefix: " << prefix << std::endl + << "--- fullpath: " << fullpath << std::endl + << "--- cmd: " << cmd << std::endl; + // open the output file & write out the source compileToSource(tmpdir, libname); @@ -172,6 +178,15 @@ void* Module::getFuncPtr(std::string name) { return dlsym(lib_handle, name.data()); } +void* Module::getFuncPtr(std::string& sofile, std::string name) { + std::cout << "opening shared object " << sofile << std::endl; + if (so_lib_handle) { + dlclose(so_lib_handle); + } + so_lib_handle = dlopen(sofile.data(), RTLD_NOW | RTLD_LOCAL); + return dlsym(so_lib_handle, name.data()); +} + int Module::callFuncPackedRaw(std::string name, void** args) { typedef int (*fnptr_t)(void**); static_assert(sizeof(void*) == sizeof(fnptr_t), @@ -210,5 +225,45 @@ int Module::callFuncPackedRaw(std::string name, void** args) { return ret; } +int Module::callFuncPackedRaw(std::string name, std::string& sofile, void** args) { + typedef int (*fnptr_t)(void**); + static_assert(sizeof(void*) == sizeof(fnptr_t), + "Unable to cast dlsym() returned void pointer to function pointer"); + void* v_func_ptr = getFuncPtr(sofile, name); + fnptr_t func_ptr; + *reinterpret_cast(&func_ptr) = v_func_ptr; + +#if USE_OPENMP + omp_sched_t existingSched; + ParallelSchedule tacoSched; + int existingChunkSize, tacoChunkSize; + int existingNumThreads = omp_get_max_threads(); + omp_get_schedule(&existingSched, &existingChunkSize); + taco_get_parallel_schedule(&tacoSched, &tacoChunkSize); + switch (tacoSched) { + case ParallelSchedule::Static: + omp_set_schedule(omp_sched_static, tacoChunkSize); + break; + case ParallelSchedule::Dynamic: + omp_set_schedule(omp_sched_dynamic, tacoChunkSize); + break; + default: + break; + } + omp_set_num_threads(taco_get_num_threads()); +#endif + + std::cout << "calling the function\n"; + int ret = func_ptr(args); + std::cout << "function call completed\n"; + +#if USE_OPENMP + omp_set_schedule(existingSched, existingChunkSize); + omp_set_num_threads(existingNumThreads); +#endif + + return ret; +} + } // namespace ir } // namespace taco diff --git a/src/index_notation/transformations.cpp b/src/index_notation/transformations.cpp index 47fc1dd55..6a67a3ed2 100644 --- a/src/index_notation/transformations.cpp +++ b/src/index_notation/transformations.cpp @@ -1,5 +1,7 @@ #include "taco/index_notation/transformations.h" +#include "lower/iteration_graph.h" +#include "lower/tensor_path.h" #include "taco/index_notation/index_notation.h" #include "taco/index_notation/index_notation_rewriter.h" #include "taco/index_notation/index_notation_nodes.h" @@ -1321,6 +1323,671 @@ topologicallySort(map> hardDeps, } +bool checkFromBack(const TensorPath& resultTensorPath, + const vector& tensorPaths, + string& removedAccessNode, + vector& producerVars, + vector& consumerVars, + vector& modifiedResultIndexesAccessed, + vector& sortedAllIndexes) { + + std::cout << "check from back function execution\n"; + + const std::vector& resultIndexesVisited = resultTensorPath.getVariables(); + IndexVar lastVisitedIndexVar = resultIndexesVisited.back(); + + std::cout << "last visited index variable: " << lastVisitedIndexVar << std::endl; + + bool onlyLastTensorContainLastIndexOfOutput = true; + bool fissionFromBack = false; + + // check from the back + for (unsigned long i=0; i& indexesVisited = otherIndexPaths.getVariables(); + cout << "index paths: " << otherIndexPaths << endl; + + // if (i < tensorPaths.size()-1) { + // check if other tensors also contain last index of output tensor + for (auto index : indexesVisited) { + cout << "checking " << index << " " << lastVisitedIndexVar << endl; + if (index == lastVisitedIndexVar) { + onlyLastTensorContainLastIndexOfOutput = false; + } + } + // } + } + + if (onlyLastTensorContainLastIndexOfOutput) { // last accessed tensorVariable + const TensorPath& otherIndexPaths = tensorPaths.back(); + const vector& indexesVisited = otherIndexPaths.getVariables(); + cout << "index paths: " << otherIndexPaths << endl; + + cout << "index variable maybe removed from the back\n"; + auto lastTensorLastVisited = indexesVisited.back(); + cout << "last index last visited " << lastTensorLastVisited << endl; + + if (lastTensorLastVisited == lastVisitedIndexVar) { + cout << "we can diffuse from the back\n"; + fissionFromBack = true; + removedAccessNode = otherIndexPaths.getAccess().getTensorVar().getName(); + cout << "removed access node " << removedAccessNode << endl; + + // mark producer accessed index variables + for (auto indexVar : sortedAllIndexes) { + if (indexVar != lastVisitedIndexVar) { // add everything except the last accessed index + std::cout << "producer vars: " << indexVar << std::endl; + producerVars.push_back(indexVar); + } + } + + for (auto indexVar : sortedAllIndexes) { + if (indexVar != lastVisitedIndexVar) { + if ( + find(resultIndexesVisited.begin(), resultIndexesVisited.end(), indexVar) + != resultIndexesVisited.end() || + find(indexesVisited.begin(), indexesVisited.end(), indexVar) + != indexesVisited.end() + ) { + modifiedResultIndexesAccessed.push_back(indexVar); + } + } + } + + // // get modified index for the intermediate calculated tensor expression + // for (unsigned long j=0; j& tensorPaths, + string& removedAccessNode, + vector& producerVars, + vector& consumerVars, + vector& modifiedResultIndexesAccessed, + vector& sortedAllIndexes) { + + std::cout << "check from front function execution\n"; + + const std::vector& resultIndexesVisited = resultTensorPath.getVariables(); + IndexVar firstVisitedIndexVar = resultIndexesVisited.front(); + + std::cout << "first fisited index variable: " << firstVisitedIndexVar << std::endl; + std::cout << "tensor path size: " << tensorPaths.size() << std::endl; + + bool onlyFirstTensorContainFirstIndexOfOutput = true; + bool fissionFromFront = false; + + // check from the front + for (long i=tensorPaths.size()-1; i>0; i--) { // change tensor paths to recursively use the functionality + std::cout << "i: " << i << std::endl; + const TensorPath& otherIndexPaths = tensorPaths.at(i); + const vector& indexesVisited = otherIndexPaths.getVariables(); + cout << "index paths: " << otherIndexPaths << endl; + + if (i != 0) { // check if other tensors also contain last index of output tensor + for (auto index : indexesVisited) { + cout << "checking " << index << " " << firstVisitedIndexVar << endl; + if (index == firstVisitedIndexVar) { + onlyFirstTensorContainFirstIndexOfOutput = false; + } + } + } + } + + + if (onlyFirstTensorContainFirstIndexOfOutput) { // last accessed tensorVariable + const TensorPath& otherIndexPaths = tensorPaths.front(); + const vector& indexesVisited = otherIndexPaths.getVariables(); + cout << "index paths: " << otherIndexPaths << endl; + + cout << "index variable maybe removed from the front\n"; + auto firstTensorFirstVisited = indexesVisited.front(); + cout << "first index first visited " << firstTensorFirstVisited << endl; + + if (firstTensorFirstVisited == firstVisitedIndexVar) { + cout << "we can diffuse from the front\n"; + fissionFromFront = true; + removedAccessNode = otherIndexPaths.getAccess().getTensorVar().getName(); + cout << "removed access node " << removedAccessNode << endl; + + // mark producer accessed index variables + for (auto indexVar : sortedAllIndexes) { + if (indexVar != firstVisitedIndexVar) { // add everything except the first accessed index + producerVars.emplace_back(indexVar); + } + } + + for (auto indexVar : sortedAllIndexes) { + if (indexVar != firstVisitedIndexVar) { + if ( + find(resultIndexesVisited.begin(), resultIndexesVisited.end(), indexVar) + != resultIndexesVisited.end() || + find(indexesVisited.begin(), indexesVisited.end(), indexVar) + != indexesVisited.end() + ) { + modifiedResultIndexesAccessed.push_back(indexVar); + } + } + } + + std::cout << "printing modifiedResultIndexesAccessed\n"; + for (auto& idx : modifiedResultIndexesAccessed) { + std::cout << "modifiedResultIndexesAccessed: " << idx << std::endl; + } + std::cout << "printed modifiedResultIndexesAccessed\n"; + + // get modified index for the intermediate calculated tensor expression + // for (unsigned long j=0; j forallParallelUnit; + map forallOutputRaceStrategy; + vector sortedIndexes; + Assignment innerBody; + + SortedIndexVars() {}; + + void visit(const ForallNode* node) { + Forall forallNode(node); + IndexVar i = forallNode.getIndexVar(); + std::cout << forallNode << std::endl; + + sortedIndexes.push_back(i); + forallParallelUnit[i] = forallNode.getParallelUnit(); + forallOutputRaceStrategy[i] = forallNode.getOutputRaceStrategy(); + + if (isa(forallNode.getStmt())) { + cout << "assignment node found: " << forallNode.getStmt() << endl;; + innerBody = to(forallNode.getStmt()); + return; // Only reorder first contiguous section of ForAlls + } + + IndexNotationVisitor::visit(node); + } + }; + + std::cout << "traversing through the index statement\n"; + SortedIndexVars sortedIndexVars; + stmt.accept(&sortedIndexVars); + std::cout << std::endl; + + struct IndexExprBuilder : public IndexNotationVisitor { + + using IndexNotationVisitor::visit; + vector accessLeftToRight; + map>> indexDimensionsMap; + + void visit(const AccessNode* node) { + Access accessNode(node); + std::cout << "access node: " << accessNode << std::endl; + accessLeftToRight.push_back(accessNode); + + TensorVar tensorVar = accessNode.getTensorVar(); + + for (unsigned long i=0; i < accessNode.getIndexVars().size(); i++) { + auto var = accessNode.getIndexVars()[i]; + + if (indexDimensionsMap.find(var) != indexDimensionsMap.end()) { + indexDimensionsMap[var].emplace_back( + pair(tensorVar.getType().getShape().getDimension(i), + tensorVar.getType())); + } + else { + indexDimensionsMap[var] = { + pair( + tensorVar.getType().getShape().getDimension(i), + tensorVar.getType()) + }; + } + } + + } + + }; + + IndexExpr rhsExpr = assignment.getRhs(); + Access lhsAccess = to(assignment.getLhs()); + std::cout << "right hand side expression: " << rhsExpr << std::endl; + IndexExprBuilder indexExprBuilder; + rhsExpr.accept(&indexExprBuilder); + TensorVar resultVar = lhsAccess.getTensorVar(); + + for (auto item : indexExprBuilder.indexDimensionsMap) { + auto indexVar = item.first; + cout << "var: " << indexVar << " "; + for (auto elem : item.second) { + cout << elem.first << " " << elem.second << " " ; + } + cout << endl; + } + + + // now I have the iteration graph + IterationGraph iterationGraph = IterationGraph::make(assignment); + std::cout << "/*******************************************/\n"; + std::cout << "/********** ITERATION GRAPH ****************/\n"; + std::cout << "/*******************************************/\n"; + std::cout << iterationGraph << std::endl; + + const TensorPath& resultTensorPath = iterationGraph.getResultTensorPath(); + const std::vector& tensorPaths = iterationGraph.getTensorPaths(); + + + string removedAccessNode; + vector producerVars; // producer accessed index variables + vector consumerVars; // consumer accessed index variables + vector fusedVars; + vector modifiedResultIndexesAccessed; + bool fissionFromBack = false; + if (side == "b") { + fissionFromBack = true; + } + + if (fissionFromBack) { + fissionFromBack = checkFromBack(resultTensorPath, tensorPaths, + removedAccessNode, producerVars, consumerVars, + modifiedResultIndexesAccessed, sortedIndexVars.sortedIndexes + ); + } + + bool fissionFromFront = false; + if (side == "f") { + fissionFromFront = true; + } + if (fissionFromBack == false && fissionFromFront) { + fissionFromFront = checkFromFront(resultTensorPath, tensorPaths, + removedAccessNode, producerVars, consumerVars, + modifiedResultIndexesAccessed, sortedIndexVars.sortedIndexes + ); + } + + if (!fissionFromBack && !fissionFromFront) { + cout << "fission operation cannot be performed from the back\n"; + return stmt; + } + + vector newAccessDims{}; + for (auto var : modifiedResultIndexesAccessed) { + auto item = indexExprBuilder.indexDimensionsMap[var]; + cout << "shared vars: " << var << endl; + newAccessDims.emplace_back(item[0].first); + } + TensorVar newAccessVar(resultVar.getName() + "_inner", + Type(resultVar.getType().getDataType(), newAccessDims)); + cout << "new inner assignment statement: " << modifiedResultIndexesAccessed << std::endl; + Access newResultAccess(newAccessVar, modifiedResultIndexesAccessed); + cout << "new access variable for iterative apply: " << newResultAccess << std::endl; + + if (fissionFromBack) { + std::cout << "fission from the back is possible\n"; + } + if (fissionFromFront) { + std::cout << "fission from the front is possible\n"; + } + + // // check from the front + // struct IndexExprSeparator : public IndexNotationVisitor { + + // using IndexNotationVisitor::visit; + // vector accessLeftToRight; + + // void visit(const MulNode* node) { + // Mul mulNode(node); + // IndexExpr lhs = mulNode.getA(); + // IndexExpr rhs = mulNode.getB(); + // std::cout << "access node: " << accessNode << std::endl; + // accessLeftToRight.push_back(accessNode); + // } + + // }; + + + cout << "\n\nProducer accessed index variables\n"; + auto it = producerVars.begin(); + for (; it != producerVars.end(); it++) { + cout << *it << endl; + } + cout << "\n\nConsumer accessed index variables\n"; + it = consumerVars.begin(); + for (; it != consumerVars.end(); it++) { + cout << *it << endl; + } + cout << endl << endl; + + // check common vars that can be fused + for (auto var : sortedIndexVars.sortedIndexes) { + if (find(producerVars.begin(), producerVars.end(), var) != producerVars.end() && + find(consumerVars.begin(), consumerVars.end(), var) != consumerVars.end()) { + fusedVars.emplace_back(var); + } + else { + break; + } + } + + for (auto& fv : fusedVars) { + std::cout << "fusable vars: " << fv << std::endl; + } + + vector sharedVars; + for (auto var : sortedIndexVars.sortedIndexes) { + if (find(fusedVars.begin(), fusedVars.end(), var) == fusedVars.end() && + find(producerVars.begin(), producerVars.end(), var) != producerVars.end() && + find(consumerVars.begin(), consumerVars.end(), var) != consumerVars.end() + ) { + sharedVars.emplace_back(var); + } + } + + for (auto& sv : sharedVars) { + std::cout << "shared vars: " << sv << std::endl; + } + + vector sharedDims{}; + for (auto var : sharedVars) { + auto item = indexExprBuilder.indexDimensionsMap[var]; + cout << "shared vars: " << var << endl; + sharedDims.emplace_back(item[0].first); + } + + + // get removing tensorvars and workspace dimension + const Type& type = resultTensorPath.getAccess().getTensorVar().getType(); + const Format& format = resultTensorPath.getAccess().getTensorVar().getFormat(); + TensorVar intermediateTensor("ws", type, format); + cout << intermediateTensor << endl; + + // TensorVar A("A", Type(), taco::dense); + TensorVar tempVar("t" + resultVar.getName(), + Type(resultVar.getType().getDataType(), sharedDims)); + cout << "tensor order: " << tempVar.getOrder() << endl; + cout << "tensor format: " << tempVar.getFormat() << endl; + cout << "format order: " << tempVar.getFormat().getOrder() << endl; + + // TensorVar* a = new TensorVar("A", Type()); + // TensorVar ws("ws", Type(type(), {jdim}) ); + + // get removing indexExpr and the rest of the indexExpr + Access workspace(tempVar, sharedVars); + std::cout << "workspace access tensor: " << workspace << std::endl; + + + + // construct producer expression right hand side + cout << "generating consumer expression\n"; + IndexExpr producerExpr; + int num_muls = 0; + for (Access accessNode : indexExprBuilder.accessLeftToRight) { + std::cout << "accessNodes: " << accessNode << endl; + if (removedAccessNode != accessNode.getTensorVar().getName()) { + if (producerExpr == NULL) { + std::cout << "index expression is null"; + producerExpr = accessNode; + std::cout << "producerExpr: " << producerExpr << std::endl; + } else { + num_muls++; + producerExpr = producerExpr * accessNode; + std::cout << "producerExpr: " << producerExpr << std::endl; + } + } + } + std::cout << producerExpr << std::endl; + Assignment producerAssignment(newResultAccess, + producerExpr); + std::cout << "new inner assignment statement: " << producerAssignment << std::endl; + Assignment producerInnerBody(workspace, + producerExpr, + sortedIndexVars.innerBody.getOperator() + ); + std::cout << "producerInnerBody: " << producerInnerBody << std::endl; + + // construct consumer expression right hand side + IndexExpr consumerExpr; + if (fissionFromBack) { + consumerExpr = workspace; + } + cout << "generating consumer expression: " << consumerExpr << std::endl; + for (Access accessNode : indexExprBuilder.accessLeftToRight) { + TensorVar tv = accessNode.getTensorVar(); + std::cout << "accessNodes: " << accessNode << endl; + if (removedAccessNode == accessNode.getTensorVar().getName()) { + if (consumerExpr == NULL) { + std::cout << "index expression is null"; + consumerExpr = accessNode; + std::cout << "consumerExpr: " << consumerExpr << std::endl; + } else { + consumerExpr = consumerExpr * accessNode; + std::cout << "consumerExpr: " << consumerExpr << std::endl; + } + } + } + if (fissionFromFront) { + consumerExpr = consumerExpr * workspace; + } + Assignment consumerInnerBody(lhsAccess, + consumerExpr, + sortedIndexVars.innerBody.getOperator() + ); + + cout << "Producer inner body: " << producerInnerBody << endl; + cout << "Consumer inner body: " << consumerInnerBody << endl; + + // rewrite indexstmt + // Reorder Foralls use a rewriter in case new nodes introduced outside of Forall + struct ProducerConsumerRewriter : public IndexNotationRewriter { + using IndexNotationRewriter::visit; + + const vector& producerConsumerVars; + const vector& fusedVars; + IndexStmt innerBody; + const map forallParallelUnit; + const map forallOutputRaceStrategy; + + ProducerConsumerRewriter(const vector& producerConsumerVars, + const vector& fusedVars, IndexStmt innerBody, + const map forallParallelUnit, + const map forallOutputRaceStrategy) + : producerConsumerVars(producerConsumerVars), fusedVars(fusedVars), innerBody(innerBody), + forallParallelUnit(forallParallelUnit), forallOutputRaceStrategy(forallOutputRaceStrategy) { + } + + void visit(const ForallNode* node) { + Forall foralli(node); + IndexVar i = foralli.getIndexVar(); + cout << "going through var: " << i << endl; + + // first forall must be in collected variables + // taco_iassert(util::contains(producerVars, i)); + // std::cout << "\ninner body of the statement\n" << innerBody; + // // done in reverse order? + // for (auto it = sortedVars.rbegin(); it != sortedVars.rend(); ++it) { + // stmt = forall(*it, stmt, forallParallelUnit.at(*it), forallOutputRaceStrategy.at(*it), foralli.getUnrollFactor()); + // } + stmt = rewrite(foralli.getStmt()); + cout << "after rewrite statement: " << stmt << endl; + + // omit the index variables in the fusedVar list + if (find(fusedVars.begin(), fusedVars.end(), i) == fusedVars.end() && + find(producerConsumerVars.begin(), producerConsumerVars.end(), i) != producerConsumerVars.end()) { + stmt = forall(i, stmt, forallParallelUnit.at(i), forallOutputRaceStrategy.at(i), foralli.getUnrollFactor()); + } + } + + void visit (const AssignmentNode* node) { + cout << "assignment node: " << node << endl; + stmt = innerBody; + cout << "producerStmt: " << innerBody << endl; + cout << "stmt: " << stmt << endl; + } + + }; + ProducerConsumerRewriter producerRewriter(producerVars, fusedVars, + producerInnerBody, + sortedIndexVars.forallParallelUnit, + sortedIndexVars.forallOutputRaceStrategy); + IndexStmt producerStmt = producerRewriter.rewrite(stmt); + std::cout << "\nAfter Producer rewriter\n"; + std::cout << producerStmt << std::endl; + if (num_muls > 1) { + producerStmt = loopFusionOverFission(producerStmt, producerInnerBody, + side, iters-1); + } + + + ProducerConsumerRewriter consumerRewriter(consumerVars, fusedVars, + consumerInnerBody, + sortedIndexVars.forallParallelUnit, + sortedIndexVars.forallOutputRaceStrategy); + IndexStmt consumerStmt = consumerRewriter.rewrite(stmt); + std::cout << "\nAfter Consumer rewriter\n"; + std::cout << consumerStmt << std::endl; + + + struct CombineProducerConsumerRewriter : public IndexNotationRewriter { + + const vector& fusedVars; + IndexStmt consumerStmt; + IndexStmt producerStmt; + const map forallParallelUnit; + const map forallOutputRaceStrategy; + + CombineProducerConsumerRewriter(const vector& fusedVars, + IndexStmt producerStmt, IndexStmt consumerStmt, + const map forallParallelUnit, + const map forallOutputRaceStrategy) + : fusedVars(fusedVars), consumerStmt(consumerStmt), producerStmt(producerStmt), + forallParallelUnit(forallParallelUnit), + forallOutputRaceStrategy(forallOutputRaceStrategy) {} + + using IndexNotationRewriter::visit; + + void visit(const ForallNode* node) { + Forall foralli(node); + IndexVar i = foralli.getIndexVar(); + cout << "going through var: " << i << endl; + + // omit the index variables in the fusedVar list + if (find(fusedVars.begin(), fusedVars.end(), i) != fusedVars.end()) { + cout << "fused var in stmt\n"; + stmt = rewrite(foralli.getStmt()); + cout << "rewritten stmt: " << stmt << endl; + stmt = forall(i, stmt, forallParallelUnit.at(i), forallOutputRaceStrategy.at(i), foralli.getUnrollFactor()); + } + else { + cout << "fused var not in stmt\n"; + cout << "producerStmt: " << producerStmt << endl; + cout << "consumerStmt: " << consumerStmt << endl; + stmt = where(consumerStmt, producerStmt); + cout << "where stmt: " << stmt << endl; + } + + cout << "after rewrite statement: " << stmt << endl; + } + + }; + + CombineProducerConsumerRewriter combineRewriter(fusedVars, + producerStmt, consumerStmt, + sortedIndexVars.forallParallelUnit, + sortedIndexVars.forallOutputRaceStrategy); + IndexStmt combinedStmt = combineRewriter.rewrite(stmt); + std::cout << "\nAfter Combine rewriter\n"; + std::cout << combinedStmt << std::endl; + + + return combinedStmt; + +} + + IndexStmt reorderLoopsTopologically(IndexStmt stmt) { // Collect tensorLevelVars which stores the pairs of IndexVar and tensor // level that each tensor is accessed at diff --git a/src/tensor.cpp b/src/tensor.cpp index fab437ff1..bb695e2ed 100644 --- a/src/tensor.cpp +++ b/src/tensor.cpp @@ -10,6 +10,7 @@ #include #include +#include "../test/util.h" #include "taco/cuda.h" #include "taco/format.h" #include "taco/taco_tensor_t.h" @@ -825,6 +826,63 @@ void TensorBase::compute() { } } +void TensorBase::compute(std::ofstream& statfile, std::string& sofile) { + taco_uassert(!needsCompile()) << error::compute_without_compile; + // if (!needsCompute()) { + // return; + // } + setNeedsCompute(false); + // Sync operand tensors if needed. + auto operands = getTensors(getAssignment().getRhs()); + for (auto& operand : operands) { + // std::cout << "operand: " << operand.second << std::endl; + operand.second.syncValues(); + operand.second.removeDependentTensor(*this); + } + + auto arguments = packArguments(*this); + + taco::util::TimeResults timevalue; + bool time = true; + TOOL_BENCHMARK_TIMER2(this->content->module->callFuncPacked("compute", sofile, arguments.data()), + "\nkernel execution time: ", timevalue); + // this->content->module->callFuncPacked("compute", arguments.data()); + + if (content->assembleWhileCompute) { + setNeedsAssemble(false); + taco_tensor_t* tensorData = ((taco_tensor_t*)arguments[0]); + content->valuesSize = unpackTensorData(*tensorData, *this); + } +} + +void TensorBase::compute(std::ofstream& statfile) { + taco_uassert(!needsCompile()) << error::compute_without_compile; + // if (!needsCompute()) { + // return; + // } + setNeedsCompute(false); + // Sync operand tensors if needed. + auto operands = getTensors(getAssignment().getRhs()); + for (auto& operand : operands) { + operand.second.syncValues(); + operand.second.removeDependentTensor(*this); + } + + auto arguments = packArguments(*this); + + taco::util::TimeResults timevalue; + bool time = true; + TOOL_BENCHMARK_TIMER2(this->content->module->callFuncPacked("compute", arguments.data()), + "\nkernel execution time: ", timevalue); + // this->content->module->callFuncPacked("compute", arguments.data()); + + if (content->assembleWhileCompute) { + setNeedsAssemble(false); + taco_tensor_t* tensorData = ((taco_tensor_t*)arguments[0]); + content->valuesSize = unpackTensorData(*tensorData, *this); + } +} + void TensorBase::evaluate() { this->compile(); if (!getAssignment().getOperator().defined()) { diff --git a/tools/taco.cpp b/tools/taco.cpp index cd351a203..27d96fc4f 100644 --- a/tools/taco.cpp +++ b/tools/taco.cpp @@ -308,7 +308,7 @@ static void printCommandLine(ostream& os, int argc, char* argv[]) { } } -static bool setSchedulingCommands(vector> scheduleCommands, parser::Parser& parser, IndexStmt& stmt) { +static bool setSchedulingCommands(vector> scheduleCommands, parser::Parser& parser, IndexStmt& stmt, Assignment assignment) { auto findVar = [&stmt](string name) { ProvenanceGraph graph(stmt); for (auto v : graph.getAllIndexVars()) { @@ -352,6 +352,16 @@ static bool setSchedulingCommands(vector> scheduleCommands, parse IndexVar fused(f); stmt = stmt.fuse(findVar(i), findVar(j), fused); + } else if (command == "loopfuse") { + taco_uassert(scheduleCommand.size() == 2) + << "'loopfuse' scheduling directive takes 2 parameters: fuse(b, 2)"; + std::string side = scheduleCommand[0]; + taco_uassert(side == "b" || side == "f") + << "first parameter must be either 'f' or 'b'"; + + int iters = std::stoi(scheduleCommand[1]); + + stmt = loopFusionOverFission(stmt, assignment, side, iters); } else if (command == "split") { taco_uassert(scheduleCommand.size() == 4) << "'split' scheduling directive takes 4 parameters: split(i, i1, i2, splitFactor)"; @@ -1112,12 +1122,18 @@ int main(int argc, char* argv[]) { taco_set_parallel_schedule(sched, chunkSize); taco_set_num_threads(nthreads); - IndexStmt stmt = - makeConcreteNotation(makeReductionNotation(tensor.getAssignment())); + + Assignment assignment = tensor.getAssignment(); + IndexStmt reducedStmt = makeReductionNotation(assignment); + IndexStmt stmt = makeConcreteNotation(reducedStmt); + std::cout << "concrete index statement: " << stmt << std::endl; + // IndexStmt stmt = + // makeConcreteNotation(makeReductionNotation(tensor.getAssignment())); stmt = reorderLoopsTopologically(stmt); + std::cout << "topologically reordered loops statement: " << stmt << std::endl; if (setSchedule) { - cuda |= setSchedulingCommands(scheduleCommands, parser, stmt); + cuda |= setSchedulingCommands(scheduleCommands, parser, stmt, assignment); } else { stmt = insertTemporaries(stmt); @@ -1355,7 +1371,7 @@ int main(int argc, char* argv[]) { } IterationGraph iterationGraph; - if (printIterationGraph) { + if (printIterationGraph) { // print iteration graph iterationGraph = IterationGraph::make(tensor.getAssignment()); } From 09f28dd0a6ed079ed94e7983788a472f97a11902 Mon Sep 17 00:00:00 2001 From: Adhhitha Dias Date: Fri, 4 Mar 2022 14:47:54 -0500 Subject: [PATCH 2/2] task: add check compile for index stmts with where clause --- test/tests-indexstmt.cpp | 191 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 190 insertions(+), 1 deletion(-) diff --git a/test/tests-indexstmt.cpp b/test/tests-indexstmt.cpp index e2a972430..f14a7df58 100644 --- a/test/tests-indexstmt.cpp +++ b/test/tests-indexstmt.cpp @@ -1,10 +1,13 @@ #include "test.h" #include "test_tensors.h" #include "taco/tensor.h" +#include "taco/type.h" +#include "taco/index_notation/kernel.h" #include "taco/index_notation/index_notation.h" +#include "taco/index_notation/transformations.h" using namespace taco; -const IndexVar i("i"), j("j"), k("k"); +const IndexVar i("i"), j("j"), k("k"), l("l"), m("m"); TEST(indexstmt, assignment) { Type t(type(), {3}); @@ -83,5 +86,191 @@ TEST(indexstmt, spmm) { ); } +TEST(indexstmt, sddmm) { + Type t(type(), {3,3}); + TensorVar A("A", t, {Sparse, Dense}); + TensorVar B("B", t, {Sparse, Dense}); + TensorVar C("C", t, {Dense, Dense}); + TensorVar w("w", Type(type(),{3}), Dense); + + // the below expression is the concrete index notation + // where (consumer, producer) + IndexStmt spmm = forall(i, + forall(k, + where(forall(j, A(i,j) = w(j)), + forall(j, w(j) += B(i,k)*C(k,j)) + ) + ) + ); + + // after adding scheduling transformations to this concrete-topologically sorted index stmt + // + + std::cout << spmm << std::endl; + spmm = reorderLoopsTopologically(spmm); + std::cout << "topologically reordered loops statement: " << spmm << std::endl; + + Kernel kernel = compile(spmm); + +} + + +TEST(indexstmt, sddmmPlusSpmm) { + + // Y(i,l) = B(i,j)*C(i,k)*D(k,j) * F(j,l); + // indexstmt order i, j, k, l + //topologically reordered loops statement: forall(i, forall(k, forall(j, forall(l, Y(i,l) += B(i,j) * C(i,k) * D(k,j) * F(j,l), NotParallel, IgnoreRaces), NotParallel, IgnoreRaces), NotParallel, IgnoreRaces), NotParallel, IgnoreRaces) + + Type t(type(), {3,3}); + TensorVar Y("Y", t, {Dense, Dense}); + TensorVar B("B", t, {Dense, Sparse}); + TensorVar C("C", t, {Dense, Dense}); + TensorVar D("D", t, {Dense, Dense}); + TensorVar E("E", t, {Dense, Dense}); + + // TensorVar A("A", Type(type(),{3}), ); + TensorVar A("A", Type()); + + IndexStmt fused1 = + forall(i, + forall(j, + forall(k, + forall(l, Y(i,l) += B(i,j) * C(i,k) * D(j,k) * E(j,l)) + ) + ) + ); + + std::cout << "before topological sort" << fused1 << std::endl; + fused1 = reorderLoopsTopologically(fused1); + std::cout << "after topological sort" << fused1 << std::endl; + + Kernel kernel = compile(fused1); + + + IndexStmt fused2 = + forall(i, + forall(j, + where( + forall(l, Y(i,l) += A * E(j,l)), // consumer + forall(k, A += B(i,j)*C(i,k)*D(j,k)) // producer + ) + ) + ); + + Kernel kernel2 = compile(fused2); +} + +TEST(indexstmt, mttkrpPlusSpmm) { + + // ./bin/taco "A(i,m)=B(i,k,l)*C(k,j)*D(l,j)*E(j,m)" -f=A:dd:0,1 -f=B:sss:0,1,2 -f=C:dd:0,1 -f=D:dd:0,1 -f=E:dd:0,1 + + // i = 11, k = 5, l = 7, j = 8; + long unsigned int idim = 11, kdim = 5, ldim = 7, jdim = 8, mdim = 6; + + Type atype(type(), {idim, mdim}); + Type btype(type(), {idim, kdim, ldim}); + Type ctype(type(), {kdim, jdim}); + Type dtype(type(), {ldim, jdim}); + Type etype(type(), {jdim, mdim}); + + TensorVar A("A", atype, {Dense, Dense}); + TensorVar B("B", btype, {Sparse, Sparse, Sparse}); + TensorVar C("C", ctype, {Dense, Dense}); + TensorVar D("D", dtype, {Dense, Dense}); + TensorVar E("E", etype, {Dense, Dense}); + + TensorVar ws("ws", Type(type(), {jdim}) ); + + IndexStmt fused1 = + forall(i, + forall(k, + forall(l, + forall(j, + forall(m, A(i,m) += B(i,k,l) * C(k,j) * D(l,j) * E(j,m)) + ) + ) + ) + ); + + std::cout << "before topological sort" << fused1 << std::endl; + fused1 = reorderLoopsTopologically(fused1); + std::cout << "after topological sort" << fused1 << std::endl; + + Kernel kernel = compile(fused1); + + IndexStmt fused2 = + forall(i, + where( + forall(j, + forall(m, + A(i,m) += ws(j) * E(j,m) + ) + ) + , + forall(k, + forall(l, + forall(j, + ws(j) += B(i,k,l) * C(k,j) * D(l,j) + ) + ) + ) + ) + ); + + Kernel kernel2 = compile(fused2); + +} + +// ./bin/taco "y(i)=A(i,j)*B(j,k)*v(k)" -f=y:d:0 -f=A:dd:0,1 -f=B:dd:0,1 -f=v:d:0 +TEST(indexstmt, mmPlusSpmv) { + + // + + long unsigned int idim = 11, jdim = 8, kdim = 5; + + Type ytype(type(), {idim}); + Type atype(type(), {idim, jdim}); + Type btype(type(), {jdim, kdim}); + Type vtype(type(), {kdim}); + + TensorVar y("y", ytype, {Dense}); + TensorVar A("A", atype, {Dense, Dense}); + TensorVar B("B", btype, {Dense, Dense}); + TensorVar v("v", vtype, {Dense}); + + TensorVar ws("ws", Type(type(), {jdim}) ); + + IndexStmt fused1 = + forall(i, + forall(j, + forall(k, + forall(m, y(i) += A(i,j) * B(j,k) * v(k)) + ) + ) + ); + + std::cout << "before topological sort" << fused1 << std::endl; + fused1 = reorderLoopsTopologically(fused1); + std::cout << "after topological sort" << fused1 << std::endl; + + Kernel kernel = compile(fused1); + + IndexStmt fused2 = + where( + forall(i, + forall(j, + y(i) += A(i,j) * ws(j) + ) + ) + , + forall(j, + forall(k, + ws(j) += B(j,k) * v(k) + ) + ) + ); + + Kernel kernel2 = compile(fused2); +}