Disables aotriton download when both USE_FLASH_ATTENTION and USE_MEM_EFF_ATTENTION cmake flags are OFF

Backports upstream PR to 2.3.0: https://github.com/pytorch/pytorch/pull/130197

--- a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
+++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
@@ -659,7 +659,7 @@ bool can_use_mem_efficient_attention(sdp_params const& params, bool debug) {
       array_of(at::kHalf, at::kFloat, at::kBFloat16);
   constexpr auto less_than_sm80_mem_efficient_dtypes =
       array_of(at::kHalf, at::kFloat);
-#ifdef USE_ROCM
+#if defined(USE_ROCM) && defined(USE_MEM_EFF_ATTENTION)
   constexpr auto aotriton_mem_efficient_dtypes =
       array_of(at::kHalf, at::kFloat, at::kBFloat16);
 #endif
@@ -709,7 +709,7 @@ bool can_use_mem_efficient_attention(sdp_params const& params, bool debug) {
     }
   }
 
-#ifdef USE_ROCM
+#if defined(USE_ROCM) && defined(USE_MEM_EFF_ATTENTION)
  return check_tensor_dtype(params, aotriton_mem_efficient_dtypes, debug);
 #else
  auto dprop = at::cuda::getCurrentDeviceProperties();
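
Note: below is a minimal, standalone sketch (not part of the patch) illustrating the guard pattern the hunks above switch to. The macro names USE_ROCM and USE_MEM_EFF_ATTENTION match the real compile definitions; the program itself is hypothetical and only shows that when the mem-efficient attention flag is off, the aotriton-specific branch is never compiled, so the aotriton dependency (and its download) can be skipped in that configuration.

// Illustrative only -- not part of the patch.
// Build with e.g. -DUSE_ROCM -DUSE_MEM_EFF_ATTENTION to take the aotriton branch.
#include <cstdio>

int main() {
#if defined(USE_ROCM) && defined(USE_MEM_EFF_ATTENTION)
  // Both flags set: the aotriton dtype check path would be compiled in.
  std::printf("aotriton mem-efficient attention path compiled\n");
#else
  // USE_MEM_EFF_ATTENTION (or USE_ROCM) is off: no aotriton symbols are
  // referenced, so the build does not need the aotriton package at all.
  std::printf("generic path: aotriton not required\n");
#endif
  return 0;
}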