Fix uninitialized variable in MIOpenDriver gemm and restore gemmfp16 for testing Upstream bug: https://github.com/ROCmSoftwarePlatform/MIOpen/issues/2505 --- a/driver/driver.hpp +++ b/driver/driver.hpp @@ -141,7 +141,7 @@ inline void PadBufferSize(size_t& sz, int datatype_sz) printf("Usage: ./driver *base_arg* *other_args*\n"); printf("Supported Base Arguments: conv[fp16|int8|bfp16], CBAInfer[fp16], " "pool[fp16], lrn[fp16], " - "activ[fp16], softmax[fp16], bnorm[fp16], rnn[fp16], gemm, ctc, dropout[fp16], " + "activ[fp16], softmax[fp16], bnorm[fp16], rnn[fp16], gemm[fp16], ctc, dropout[fp16], " "tensorop[fp16], reduce[fp16,fp64]\n"); exit(0); // NOLINT (concurrency-mt-unsafe) } @@ -160,7 +160,7 @@ inline std::string ParseBaseArg(int argc, char* argv[]) arg != "CBAInfer" && arg != "CBAInferfp16" && arg != "pool" && arg != "poolfp16" && arg != "lrn" && arg != "lrnfp16" && arg != "activ" && arg != "activfp16" && arg != "softmax" && arg != "softmaxfp16" && arg != "bnorm" && arg != "bnormfp16" && - arg != "rnn" && arg != "rnnfp16" && arg != "gemm" /*&& arg != "gemmfp16"*/ && arg != "ctc" && + arg != "rnn" && arg != "rnnfp16" && arg != "gemm" && arg != "gemmfp16" && arg != "ctc" && arg != "dropout" && arg != "dropoutfp16" && arg != "tensorop" && arg != "tensoropfp16" && arg != "reduce" && arg != "reducefp16" && arg != "reducefp64" && arg != "--version") { --- a/driver/gemm_driver.hpp +++ b/driver/gemm_driver.hpp @@ -207,6 +207,19 @@ int GemmDriver::GetandSetData() gemm_desc.strideB = gemm_desc.k * gemm_desc.n; gemm_desc.strideC = gemm_desc.m * gemm_desc.n; + if constexpr (std::is_same_v) + { + gemm_desc.dataType = miopenFloat; + } + else if constexpr (std::is_same_v) + { + gemm_desc.dataType = miopenHalf; + } + else + { + static_assert(!"unsupported type"); + } + return (0); } @@ -230,9 +243,9 @@ int GemmDriver::AllocateBuffersAndCopy() a = std::vector(a_sz); b = std::vector(b_sz); #if GEMM_DRIVER_DEBUG - c = std::vector(c_sz, 1.); + c = std::vector(c_sz, static_cast(1.)); #else - c = std::vector(c_sz, 0.); + c = std::vector(c_sz, static_cast(0.)); #endif chost = c; --- a/driver/main.cpp +++ b/driver/main.cpp @@ -125,11 +125,10 @@ int main(int argc, char* argv[]) { drv = new GemmDriver(); } -// TODO half is not supported in gemm -// else if(base_arg == "gemmfp16") -// { -// drv = new GemmDriver(); -// } + else if(base_arg == "gemmfp16") + { + drv = new GemmDriver(); + } #endif else if(base_arg == "bnorm") {