Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
b5b20b8
update
AnnaTrainingG Dec 6, 2023
199b9d6
has data
AnnaTrainingG Dec 6, 2023
a582b3a
update
AnnaTrainingG Dec 6, 2023
78080dd
update
AnnaTrainingG Dec 6, 2023
0d6766e
update
AnnaTrainingG Dec 6, 2023
17b89c9
updat
AnnaTrainingG Dec 6, 2023
1355060
update
AnnaTrainingG Dec 6, 2023
d536119
update
AnnaTrainingG Dec 6, 2023
66fc8a7
update
AnnaTrainingG Dec 6, 2023
41ebd07
all
AnnaTrainingG Dec 6, 2023
ad614e0
update
AnnaTrainingG Dec 6, 2023
c8d003a
update
AnnaTrainingG Dec 6, 2023
559a479
update
AnnaTrainingG Dec 6, 2023
bd670ae
80 90
AnnaTrainingG Dec 6, 2023
e4b5006
error
AnnaTrainingG Dec 6, 2023
4f7f1f0
update build ok
AnnaTrainingG Dec 6, 2023
8c12f72
update
AnnaTrainingG Dec 6, 2023
7b257e8
update
AnnaTrainingG Dec 6, 2023
4fd33ea
updaet
AnnaTrainingG Dec 6, 2023
f03a1df
updaet
AnnaTrainingG Dec 6, 2023
d810108
upate
AnnaTrainingG Dec 6, 2023
48eb647
update
AnnaTrainingG Dec 6, 2023
7bb6f31
update
AnnaTrainingG Dec 6, 2023
58563ba
udpate
AnnaTrainingG Dec 6, 2023
e856a05
update
AnnaTrainingG Dec 6, 2023
256a3c6
update
AnnaTrainingG Dec 7, 2023
af386bf
update
AnnaTrainingG Dec 7, 2023
3aca223
Update
AnnaTrainingG Dec 7, 2023
06edc27
update
AnnaTrainingG Dec 8, 2023
45fcc53
update
AnnaTrainingG Dec 8, 2023
940a8ae
default
AnnaTrainingG Dec 8, 2023
6b6c7a8
update
AnnaTrainingG Dec 8, 2023
18ae756
update equal
AnnaTrainingG Dec 8, 2023
d926c09
for so
AnnaTrainingG Dec 10, 2023
a2714eb
Update CMakeLists.txt
AnnaTrainingG Dec 11, 2023
a61e35b
update fa1 mask
AnnaTrainingG Dec 11, 2023
600d748
Update for fa extends
AnnaTrainingG Dec 14, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 35 additions & 7 deletions csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
cmake_minimum_required(VERSION 3.9 FATAL_ERROR)
project(flash-attention LANGUAGES CXX CUDA)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

find_package(Git QUIET REQUIRED)

execute_process(COMMAND ${GIT_EXECUTABLE} submodule update --init --recursive
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
RESULT_VARIABLE GIT_SUBMOD_RESULT)

#cmake -DWITH_ADVANCED=ON
if (WITH_ADVANCED)
add_compile_definitions(PADDLE_WITH_ADVANCED)
endif()

add_definitions("-DFLASH_ATTN_WITH_TORCH=0")

set(CUTLASS_3_DIR ${CMAKE_CURRENT_SOURCE_DIR}/cutlass)
Expand Down Expand Up @@ -55,6 +62,7 @@ target_include_directories(flashattn PRIVATE
flash_attn
${CUTLASS_3_DIR}/include)

if (WITH_ADVANCED)
set(FA1_SOURCES_CU
flash_attn_with_bias_and_mask/flash_attn_with_bias_mask.cu
flash_attn_with_bias_and_mask/src/cuda_utils.cu
Expand All @@ -65,6 +73,12 @@ set(FA1_SOURCES_CU
flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim64.cu
flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim128.cu
flash_attn_with_bias_and_mask/src/utils.cu)
else()
set(FA1_SOURCES_CU
flash_attn_with_bias_and_mask/flash_attn_with_bias_mask.cu
flash_attn_with_bias_and_mask/src/cuda_utils.cu
flash_attn_with_bias_and_mask/src/utils.cu)
endif()

add_library(flashattn_with_bias_mask STATIC
flash_attn_with_bias_and_mask/
Expand All @@ -83,18 +97,14 @@ target_link_libraries(flashattn flashattn_with_bias_mask)

add_dependencies(flashattn flashattn_with_bias_mask)

set(NVCC_ARCH_BIN 80 CACHE STRING "CUDA architectures")

if (NOT DEFINED NVCC_ARCH_BIN)
message(FATAL_ERROR "NVCC_ARCH_BIN is not defined.")
endif()

if (NVCC_ARCH_BIN STREQUAL "")
message(FATAL_ERROR "NVCC_ARCH_BIN is not set.")
endif()
message("NVCC_ARCH_BIN is set to: ${NVCC_ARCH_BIN}")

STRING(REPLACE "-" ";" FA_NVCC_ARCH_BIN ${NVCC_ARCH_BIN})

set(FA_GENCODE_OPTION "SHELL:")

foreach(arch ${FA_NVCC_ARCH_BIN})
if(${arch} GREATER_EQUAL 80)
set(FA_GENCODE_OPTION "${FA_GENCODE_OPTION} -gencode arch=compute_${arch},code=sm_${arch}")
Expand Down Expand Up @@ -131,7 +141,25 @@ target_compile_options(flashattn_with_bias_mask PRIVATE $<$<COMPILE_LANGUAGE:CUD
"${FA_GENCODE_OPTION}"
>)


INSTALL(TARGETS flashattn
LIBRARY DESTINATION "lib")

INSTALL(FILES capi/flash_attn.h DESTINATION "include")

if (WITH_ADVANCED)
set_target_properties(flashattn PROPERTIES
OUTPUT_NAME libflashattn_advanced
PREFIX ""
)
add_custom_target(build_whl
COMMAND ${CMAKE_COMMAND} -E env python ${CMAKE_SOURCE_DIR}/setup.py bdist_wheel
WORKING_DIRECTORY ${CMAKE_BINARY_DIR}
DEPENDS flashattn
COMMENT "Running build wheel"
)

add_custom_target(default_target DEPENDS build_whl)

set_property(DIRECTORY PROPERTY DEFAULT_TARGET default_target)
endif()
Loading