VK_QCOM_cooperative_matrix_conversion.proposal

This document proposes a new extension that adds shader instructions complementing VK_KHR_cooperative_matrix.

Problem Statement

The baseline cooperative matrix extension provides a significant performance boost for simple matrix multiplication operations when data is loaded directly from and stored directly to memory.

However, most use cases that leverage matrix multiplication hardware, such as convolutions and large language models, require additional manipulation of input and output data that the opaque cooperative matrix objects do not support directly.

The cooperative matrix extension explicitly requires staging data through shared memory to perform these invocation-level manipulations. An extension is needed that allows implementations to provide optimized data conversions between invocation scope and subgroup scope without explicitly going through shared memory.

Solution Space

Explicit conversion instructions allow the application to bypass staging through shared memory.

Proposal

This proposal exposes support for the SPV_QCOM_cooperative_matrix_conversion SPIR-V extension. This extension provides the following:

  • Capabilities:
    • CooperativeMatrixConversionQCOM
  • Instructions:
    • OpBitCastArrayQCOM
    • OpCompositeConstructCoopMatQCOM
    • OpCompositeExtractCoopMatQCOM
    • OpExtractSubArrayQCOM

This extension is compatible with the GLSL extension GLSL_QCOM_cooperative_matrix_conversion.

Unlike GLSL_KHR_cooperative_matrix.txt, these instructions permit the per-invocation vectors to be located in private memory, which allows workflows where shared-memory staging is not required. They can also let applications write less vendor-dependent shaders by bypassing vendor-specific loading constraints and data shuffling.

OpBitCastArrayQCOM

OpBitCastArrayQCOM can be used to perform a bit-cast conversion between compatible one-dimensional arrays. The GLSL equivalent is:

void bitcastQCOM(SrcTy srcArr[SrcLen], DstTy dstArr[DstLen]);
  • The size in bytes of the source and destination arrays must be the same
  • Valid types include int32_t, uint32_t, float32_t, float16_t

Examples:

uint32_t uvecA[8];
float    vecB[8];
bitcastQCOM(vecB, uvecA);

float16_t f16_vecB[16];
bitcastQCOM(f16_vecB, uvecA);
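The bit-cast semantics can be illustrated with a host-side C model (not shader code). This is a minimal sketch assuming IEEE 754 float32 on the host; the function name bitcast_array is hypothetical, and float16_t is omitted because standard C has no portable half-precision type.

```c
#include <assert.h>
#include <stdint.h>
#include <string.h>

/* Hypothetical host-side model of OpBitCastArrayQCOM: reinterpret the bytes
 * of src as dst. Mirrors the rule that both arrays have the same byte size. */
static void bitcast_array(const void *src, size_t src_bytes,
                          void *dst, size_t dst_bytes)
{
    assert(src_bytes == dst_bytes); /* sizes must match exactly */
    memcpy(dst, src, src_bytes);    /* copy bits, no numeric conversion */
}
```

Calling bitcast_array(vecB, sizeof vecB, uvecA, sizeof uvecA) on the host mirrors bitcastQCOM(vecB, uvecA) above; arrays of different byte sizes trip the assertion, just as they would be invalid in a shader.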

OpCompositeConstructCoopMatQCOM

OpCompositeConstructCoopMatQCOM can be used to construct a cooperative matrix cooperatively using a vector from each subgroup invocation. The GLSL equivalent is:

void vectorToCoopmatQCOM(SrcEltTy vec[SrcVecLen], coopmat<DstEltTy, gl_ScopeSubgroup, NumRows, NumCols, CoopMatUse> cm);
  • CoopMatUse determines the type of cooperative matrix and must be one of gl_MatrixUseA, gl_MatrixUseB, or gl_MatrixUseAccumulator

gl_MatrixUseA

When CoopMatUse is gl_MatrixUseA, the per-invocation vectors fill in the rows of the matrix: the vector of the invocation with gl_SubgroupInvocationID == i fills row i. Vectors from invocations whose gl_SubgroupInvocationID is greater than or equal to NumRows are ignored and not copied.

NumRows must be a constant less than or equal to gl_SubgroupSize, and each row of the matrix must be exactly 32 bytes long (NumCols multiplied by the element size in bytes must equal 32). The permitted matrix element types are float32_t, float16_t, uint8_t, and int8_t.

Both NumRows and NumCols are permitted to be specialization constants.

The source and destination element types must match, and the length of the source vector must equal the number of columns NumCols of the cooperative matrix.

However, a special case is allowed in which the source type is uint32_t and the source length is 8, which defines an implicit bit cast to DstEltTy. Applications may get better performance by packing source elements into uint32_t rather than passing float16_t or 8-bit integer vectors directly.

coopmat<int8_t, gl_ScopeSubgroup, 64, 32, gl_MatrixUseA> si8_matA;
int8_t                                                   sivec[32];

// Load the 64 invocation vectors of int8[32] into each row of the 64x32 matrix
vectorToCoopmatQCOM(sivec, si8_matA);

// Alternative efficient packed upload case, with each element of uivec packed with 4 elements of matrix
uint32_t uivec[8];
vectorToCoopmatQCOM(uivec, si8_matA);
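The gl_MatrixUseA row mapping can be sketched as a host-side C model that treats every row as 32 raw bytes. The function vector_to_matrix_use_a and the row-major host layout are hypothetical. Copying bytes also covers the packed uint32_t[8] form, but only under the assumption that the device packs 8-bit elements into a uint32_t in the same order as the host's memory layout, which the proposal does not state.

```c
#include <assert.h>
#include <stdint.h>
#include <string.h>

#define ROW_BYTES 32 /* each gl_MatrixUseA row must be exactly 32 bytes */

/* Hypothetical host model of vectorToCoopmatQCOM for gl_MatrixUseA.
 * vecs holds one 32-byte vector per subgroup invocation, back to back; the
 * vector of invocation i becomes row i of the row-major matrix. Invocations
 * at or beyond num_rows are ignored, so their data is never copied. */
static void vector_to_matrix_use_a(const void *vecs, uint32_t subgroup_size,
                                   void *matrix, uint32_t num_rows)
{
    assert(num_rows <= subgroup_size);
    const uint8_t *src = (const uint8_t *)vecs;
    uint8_t *dst = (uint8_t *)matrix;
    for (uint32_t row = 0; row < num_rows; row++)
        memcpy(dst + (size_t)row * ROW_BYTES,
               src + (size_t)row * ROW_BYTES, ROW_BYTES);
}
```

Because the model copies raw bytes, passing int8_t[32] vectors or the same data reinterpreted as uint32_t[8] produces the same matrix, which is the equivalence the packed special case relies on.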

gl_MatrixUseB

When CoopMatUse is gl_MatrixUseB, the operation works like gl_MatrixUseA except that the per-invocation vectors fill in the columns of the matrix. The same rules apply with NumRows and NumCols swapped.

coopmat<float32_t, gl_ScopeSubgroup, 8, 64, gl_MatrixUseB> f32_matB;
float32_t                                                  fvec[8];

// Load the 64 invocation vectors of fp32[8] into each column of the 8x64 matrix
// Fp32 is already efficiently packed, so packing into uint32 is not necessary
vectorToCoopmatQCOM(fvec, f32_matB);
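For gl_MatrixUseB, a similar host-side sketch copies each invocation's vector into a column instead of a row. The function name, the row-major host storage, and the byte-level element copy are illustrative assumptions, not part of the extension.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

/* Hypothetical host model of vectorToCoopmatQCOM for gl_MatrixUseB.
 * The matrix is stored row-major with num_rows x num_cols elements of
 * elem_size bytes. Invocation c supplies a vector of num_rows elements that
 * becomes column c; invocations at or beyond num_cols are ignored. */
static void vector_to_matrix_use_b(const void *vecs, uint32_t subgroup_size,
                                   void *matrix, uint32_t num_rows,
                                   uint32_t num_cols, size_t elem_size)
{
    assert(num_cols <= subgroup_size);
    assert(num_rows * elem_size == 32); /* each column must be 32 bytes */
    const uint8_t *src = (const uint8_t *)vecs;
    uint8_t *dst = (uint8_t *)matrix;
    for (uint32_t col = 0; col < num_cols; col++)
        for (uint32_t row = 0; row < num_rows; row++)
            memcpy(dst + ((size_t)row * num_cols + col) * elem_size,
                   src + ((size_t)col * num_rows + row) * elem_size,
                   elem_size);
}
```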

gl_MatrixUseAccumulator

When CoopMatUse is gl_MatrixUseAccumulator, the operation works like gl_MatrixUseA, but different constraints apply.

Each row of the matrix must be gl_SubgroupSize, gl_SubgroupSize/2, or gl_SubgroupSize/4 elements long. The permitted matrix element types are float32_t, float16_t, uint32_t, and int32_t.

The source and destination element types must match, and the length of the source vector must equal the number of columns NumCols of the cooperative matrix.

However, a special case is allowed in which the source type is uint32_t, the destination type is float16_t, and the source length is NumCols/2, which defines an implicit bit cast. Applications may get better performance by packing pairs of float16_t values into uint32_t rather than passing float16_t vectors directly.

coopmat<float16_t, gl_ScopeSubgroup, 64, 32, gl_MatrixUseAccumulator> fp16_matAcc;
float16_t                                                             fpvec[32];

// Load the 64 invocation vectors of float16_t[32] into each row of the 64x32 matrix
vectorToCoopmatQCOM(fpvec, fp16_matAcc);

// Alternative efficient packed upload case, with each element of uivec packed with 2 elements of matrix
uint32_t uivec[16];
vectorToCoopmatQCOM(uivec, fp16_matAcc);
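The gl_MatrixUseAccumulator size rules can be written as a small validity check. This is a hypothetical host-side helper, shown only to make the arithmetic concrete: with a subgroup size of 64, NumCols of 64, 32, or 16 is accepted, and a 32-column float16_t accumulator takes either 32 float16_t or 16 packed uint32_t elements per invocation.

```c
#include <assert.h>
#include <stdbool.h>
#include <stdint.h>

/* Hypothetical check of the gl_MatrixUseAccumulator row-length rule:
 * NumCols must be gl_SubgroupSize, gl_SubgroupSize/2, or gl_SubgroupSize/4. */
static bool accumulator_row_length_ok(uint32_t num_cols, uint32_t subgroup_size)
{
    return num_cols == subgroup_size ||
           num_cols == subgroup_size / 2 ||
           num_cols == subgroup_size / 4;
}

/* Source vector length required by vectorToCoopmatQCOM: NumCols when the
 * element types match, or NumCols / 2 when float16_t data is packed two
 * per uint32_t. */
static uint32_t accumulator_source_length(uint32_t num_cols, bool packed_f16_in_u32)
{
    return packed_f16_in_u32 ? num_cols / 2 : num_cols;
}
```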

OpCompositeExtractCoopMatQCOM

OpCompositeExtractCoopMatQCOM can be used to concurrently extract rows or columns from a cooperative matrix into per-invocation vectors. The GLSL equivalent is:

void coopmatToVectorQCOM(coopmat<SrcEltTy, gl_ScopeSubgroup, NumRows, NumCols, CoopMatUse> cm, DstEltTy vec[DstVecLen]);

This performs the inverse operation of OpCompositeConstructCoopMatQCOM.
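Since extraction is the inverse of construction, a host-side model only needs to reverse the copy direction. This sketch assumes a row-major host matrix and that invocation i receives row i (gl_MatrixUseA, gl_MatrixUseAccumulator) or column i (gl_MatrixUseB), mirroring the construction rules. The function coopmat_to_vectors is hypothetical, and because the proposal does not say what invocations beyond the matrix dimension receive, the model produces only one vector per row or column.

```c
#include <assert.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>

/* Hypothetical host model of coopmatToVectorQCOM. The matrix is stored
 * row-major (num_rows x num_cols). With by_column == false, invocation i
 * receives row i; with by_column == true, invocation i receives column i.
 * vecs receives one vector per invocation, stored back to back. */
static void coopmat_to_vectors(const float *matrix, uint32_t num_rows,
                               uint32_t num_cols, bool by_column, float *vecs)
{
    uint32_t count = by_column ? num_cols : num_rows; /* vectors produced */
    uint32_t len = by_column ? num_rows : num_cols;   /* elements per vector */
    for (uint32_t inv = 0; inv < count; inv++)
        for (uint32_t k = 0; k < len; k++) {
            uint32_t row = by_column ? k : inv;
            uint32_t col = by_column ? inv : k;
            vecs[(size_t)inv * len + k] = matrix[(size_t)row * num_cols + col];
        }
}
```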

OpExtractSubArrayQCOM

OpExtractSubArrayQCOM can be used to slice an array into a sub-array. The GLSL equivalent is:

void extractSubArrayQCOM(EltTy srcArr[SrcLen], int start_index, EltTy dstArr[DstLen]);

Copies DstLen elements from srcArr starting at start_index into dstArr.

  • EltTy must be one of uint32_t, int32_t, float32_t, or float16_t
  • The application must not specify a copy region that extends beyond the bounds of the source array
  • SrcLen must equal VkCooperativeMatrixPropertiesKHR::NSize for one of the structures enumerated by vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR
  • DstLen must equal VkCooperativeMatrixPropertiesKHR::KSize for one of the enumerated structures
  • start_index must be a multiple of VkCooperativeMatrixPropertiesKHR::KSize for one of the enumerated structures

The NSize and KSize values do not need to come from the same enumerated VkCooperativeMatrixPropertiesKHR structure, but DstLen and start_index must use the same KSize value.

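float32_t vecAcc[32]; // SrcLen 32 is assumed to be an enumerated NSize
float32_t vecA[8];    // DstLen 8 is assumed to be an enumerated KSize
extractSubArrayQCOM(vecAcc, 8, vecA);  // copies vecAcc[8..15]
extractSubArrayQCOM(vecAcc, 24, vecA); // copies vecAcc[24..31]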
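The constraints above can be checked on the host before a shader is specialized. This is a minimal sketch: sub_array_args_ok and extract_sub_array are hypothetical helpers, and n_size and k_size stand for NSize and KSize values the application is assumed to have read from vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR (they are separate parameters because they may come from different enumerated structures).

```c
#include <assert.h>
#include <stdbool.h>
#include <stdint.h>
#include <string.h>

/* Hypothetical check of the extractSubArrayQCOM constraints for one
 * candidate NSize/KSize pair. */
static bool sub_array_args_ok(uint32_t src_len, uint32_t dst_len,
                              uint32_t start_index,
                              uint32_t n_size, uint32_t k_size)
{
    return src_len == n_size &&
           dst_len == k_size &&
           start_index % k_size == 0 &&       /* aligned to KSize */
           start_index + dst_len <= src_len;  /* no out-of-bounds copy */
}

/* Hypothetical host model of the copy itself. */
static void extract_sub_array(const float *src, uint32_t start_index,
                              float *dst, uint32_t dst_len)
{
    memcpy(dst, src + start_index, dst_len * sizeof *dst);
}
```

With NSize = 32 and KSize = 8, start indices 0, 8, 16, and 24 are valid for an 8-element destination, while 3 is rejected because it is not a multiple of KSize.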

Features

The following feature structure is proposed:

typedef struct VkPhysicalDeviceCooperativeMatrixConversionFeaturesQCOM {
    VkStructureType    sType;
    void*              pNext;
    VkBool32           cooperativeMatrixConversion;
} VkPhysicalDeviceCooperativeMatrixConversionFeaturesQCOM;
  • cooperativeMatrixConversion specifies whether shaders can declare the CooperativeMatrixConversionQCOM capability

Examples

Convolution

This example uses the implicit im2col technique with an NHWC tensor:

for (uint32_t step = 0; step < TOTAL_K; step += TILE_K)
{
    uint32_t subMatrixBStartInElements = col * FILTER_H * FILTER_W * strideBinElements + step;
    for (int32_t filter_row = 0; filter_row < FILTER_H; filter_row++)
        for (int32_t filter_col = 0; filter_col < FILTER_W; filter_col++)
        {
            // load matB input data using the coopmat extension
            coopmat<float, gl_ScopeSubgroup, TILE_K, TILE_N, gl_MatrixUseB> matB;
            coopMatLoad(matB, inputB, subMatrixBStartInElements, FILTER_H * FILTER_W * strideBinElements, LAYOUT_K_FIRST);

            // input pixel for this invocation's output pixel; signed because
            // taps in the padding region produce negative coordinates
            int32_t input_row = STRIDE * int32_t(out_row) + DILATION * (filter_row - FILTER_H/2);
            int32_t input_col = STRIDE * int32_t(out_col) + DILATION * (filter_col - FILTER_W/2);
            bool inBounds = (input_row >= 0) && (input_row < INPUT_H) && (input_col >= 0) && (input_col < INPUT_W);

            // load vecA input data as vectors using regular vector loads,
            // zero filling out-of-bounds taps instead of reading past the tensor
            float vecA[TILE_K];
            for (int i=0; i<TILE_K; i++)
              vecA[i] = float(0);
            if (inBounds)
              for (int i=0; i<TILE_K; i++)
                vecA[i] = inputA[(input_row * INPUT_W + input_col) * strideAinElements + step + i];

            // convert vecA to matA (rows must be 32 bytes, so TILE_K is 8 for float)
            // and accumulate into matC, which is declared outside this loop
            coopmat<float, gl_ScopeSubgroup, TILE_M, TILE_K, gl_MatrixUseA> matA;
            vectorToCoopmatQCOM(vecA, matA);
            matC = coopMatMulAdd(matA, matB, matC);

            subMatrixBStartInElements += strideBinElements;
        }
}
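The address arithmetic in the loop above can be checked on the host. This hypothetical C helper mirrors the input_row and input_col computation for an NHWC tensor and returns -1 for taps that the shader zero fills. The conv_params structure and its field names are assumptions standing in for the shader constants, and k corresponds to step + i in the shader.

```c
#include <assert.h>
#include <stdint.h>

/* Hypothetical stand-ins for the shader constants. */
typedef struct {
    int32_t input_h, input_w;   /* INPUT_H, INPUT_W */
    int32_t filter_h, filter_w; /* FILTER_H, FILTER_W */
    int32_t stride, dilation;   /* STRIDE, DILATION */
    int32_t channels;           /* strideAinElements: elements per pixel */
} conv_params;

/* Element offset of channel k for one output pixel and filter tap, or -1
 * when the tap lies in the zero-padded border. */
static int64_t im2col_offset(const conv_params *p, int32_t out_row, int32_t out_col,
                             int32_t filter_row, int32_t filter_col, int32_t k)
{
    int32_t in_row = p->stride * out_row + p->dilation * (filter_row - p->filter_h / 2);
    int32_t in_col = p->stride * out_col + p->dilation * (filter_col - p->filter_w / 2);
    if (in_row < 0 || in_row >= p->input_h || in_col < 0 || in_col >= p->input_w)
        return -1; /* the shader zero fills vecA for this tap */
    return ((int64_t)in_row * p->input_w + in_col) * p->channels + k;
}
```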

Neural Texture Decompression

coopmat<float16_t, gl_ScopeSubgroup, 64, 16, gl_MatrixUseA>           matA0, matA1;
coopmat<float16_t, gl_ScopeSubgroup, 16, 16, gl_MatrixUseB>           matB0, matB1;
coopmat<float16_t, gl_ScopeSubgroup, 64, 16, gl_MatrixUseAccumulator> matC1, matC2;
matC1 = coopmat<float16_t, gl_ScopeSubgroup, 64, 16, gl_MatrixUseAccumulator>(0.0);
matC2 = coopmat<float16_t, gl_ScopeSubgroup, 64, 16, gl_MatrixUseAccumulator>(0.0);

// load matB0 and matB1 matrix input data using coopmat extension
coopMatLoad(matB0, inputB0, subMatrixBStartInElements, strideBinElements, LAYOUT_K_FIRST);
coopMatLoad(matB1, inputB1, subMatrixBStartInElements, strideBinElements, LAYOUT_K_FIRST);

// load vecA0 (input features) any way you like
float16_t vecA0[16];
for (int i=0; i<16; i++) vecA0[i] = inputA.x[inputCoord + i];

// convert vecA0 to matA0, then execute MatMul (first layer)
vectorToCoopmatQCOM(vecA0, matA0);
matC1 = coopMatMulAdd(matA0, matB0, matC1);

// convert matC1 to matA1, then execute MatMul (second layer)
float16_t vecA1[16];
coopmatToVectorQCOM(matC1, vecA1);
vectorToCoopmatQCOM(vecA1, matA1);
matC2 = coopMatMulAdd(matA1, matB1, matC2);

// convert matC2 to vecC2 (which holds the output features)
float16_t vecC2[16];
coopmatToVectorQCOM(matC2, vecC2);
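The two-layer example can be validated against a simple host reference computed per invocation. This sketch assumes float instead of float16_t, row-major K x N storage for B0 and B1, and zero-initialized accumulators as in the example; like the example, it applies no activation between the layers. The function names are hypothetical.

```c
#include <assert.h>
#include <stddef.h>

#define FEATURES 16 /* width of vecA0, vecA1 and vecC2 */

/* out[n] = sum_k in[k] * b[k][n], with b stored row-major (K x N). */
static void row_times_matrix(const float *in, const float *b, float *out)
{
    for (size_t n = 0; n < FEATURES; n++) {
        float acc = 0.0f; /* zero-initialized accumulator, like matC1/matC2 */
        for (size_t k = 0; k < FEATURES; k++)
            acc += in[k] * b[k * FEATURES + n];
        out[n] = acc;
    }
}

/* Reference for one invocation: vecC2 = (vecA0 * B0) * B1. */
static void two_layer_reference(const float *in, const float *b0,
                                const float *b1, float *out)
{
    float hidden[FEATURES]; /* plays the role of vecA1 */
    row_times_matrix(in, b0, hidden);
    row_times_matrix(hidden, b1, out);
}
```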

Layer Merging (Flash Attention)

coopmat<float, gl_ScopeSubgroup, 64,  8, gl_MatrixUseA>           matA;
coopmat<float, gl_ScopeSubgroup, 8,  64, gl_MatrixUseB>           matB;
coopmat<float, gl_ScopeSubgroup, 64, 64, gl_MatrixUseAccumulator> matC;
coopmat<float, gl_ScopeSubgroup, 64,  8, gl_MatrixUseA>           matCtoA;
coopmat<float, gl_ScopeSubgroup, 8,  64, gl_MatrixUseB>           matD;
coopmat<float, gl_ScopeSubgroup, 64, 64, gl_MatrixUseAccumulator> matO;

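for (uint stepN = 0; stepN < TOTAL_N; stepN += TILE_N64)
{
    // reset the score tile before accumulating over K for this N step
    matC = coopmat<float, gl_ScopeSubgroup, 64, 64, gl_MatrixUseAccumulator>(0.0);
    for (uint stepK = 0; stepK < TOTAL_K; stepK += TILE_K8) {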
        coopMatLoad(matA, inputA, subMatrixAStart, STRIDE_A, LAYOUT_A);
        coopMatLoad(matB, inputB, subMatrixBStart, STRIDE_B, LAYOUT_B);
        matC = coopMatMulAdd(matA, matB, matC);
    }

    float vecC[64];
    coopmatToVectorQCOM(matC, vecC);
    vecC = online_softmax(vecC); // apply softmax on vecC using rowsum and rowmax reduction operations

    coopMatLoad(matD, inputD, subMatrixDStart + 0*STRIDE_D, STRIDE_D, LAYOUT_D);
    float vecCtoA[8];
    extractSubArrayQCOM(vecC,  0, vecCtoA); // first 8 elements of matC
    vectorToCoopmatQCOM(vecCtoA,  matCtoA);
    matO = coopMatMulAdd(matCtoA, matD, matO);
    ....
    coopMatLoad(matD, inputD, subMatrixDStart + 7*STRIDE_D, STRIDE_D, LAYOUT_D);
    extractSubArrayQCOM(vecC, 56, vecCtoA); // last 8 elements of matC
    vectorToCoopmatQCOM(vecCtoA, matCtoA);
    matO = coopMatMulAdd(matCtoA, matD, matO);
}
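Slicing vecC into 8-element pieces works because a row-vector product can be split along K: summing the products of each 8-element slice with the matching 8 rows of D gives the same row of matO as one full 64-wide product. This host-side C sketch models one invocation's row. It assumes that the elided loads pair slice c with rows 8c through 8c + 7 of a row-major D, chunked_row_update is a hypothetical name, and the online softmax step is not modeled.

```c
#include <assert.h>
#include <stddef.h>
#include <string.h>

#define ROW_LEN 64  /* elements in one row of matC (vecC) */
#define CHUNK 8     /* KSize of each vecCtoA slice */
#define OUT_COLS 64 /* columns of matD / matO */

/* o += sum over chunks of vecC[start .. start+7] * D[start .. start+7][:],
 * with d stored row-major as ROW_LEN rows of OUT_COLS columns. */
static void chunked_row_update(const float *vec_c, const float *d, float *o)
{
    for (size_t start = 0; start < ROW_LEN; start += CHUNK) {
        float slice[CHUNK];                         /* vecCtoA */
        memcpy(slice, vec_c + start, sizeof slice); /* extractSubArrayQCOM */
        for (size_t n = 0; n < OUT_COLS; n++)       /* coopMatMulAdd, one row */
            for (size_t k = 0; k < CHUNK; k++)
                o[n] += slice[k] * d[(start + k) * OUT_COLS + n];
    }
}
```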

Dequantization

4-bit weight dequantization (llama.cpp q4):

coopmat<float16_t, gl_ScopeSubgroup, TILE_M64, TILE_K16, gl_MatrixUseA>           matA;
coopmat<float16_t, gl_ScopeSubgroup, TILE_K16, TILE_N64, gl_MatrixUseB>           matB;
coopmat<float16_t, gl_ScopeSubgroup, TILE_M64, TILE_N64, gl_MatrixUseAccumulator> matC;

float16_t vecAh_block[32], vecAh[TILE_K16], vecBh[TILE_K16];
for (uint step = 0; step < TOTAL_K; step += TILE_K16)
{
    for (int k=0; k<TILE_K16; k++)
        vecBh[k] = float16_t(data_bf[(pos_b + step) + (gl_SubgroupInvocationID * p.stride_b) + k]);

    if (step%32 == 0) {
        const uint ib = (pos_a + step)/32 + gl_SubgroupInvocationID * p.stride_a/32;
        float d = float(data_a_packed16[ib].d);
        for (uint i = 0; i < 4; i++) {
            uint vui = uint(data_a_packed16[ib].qs[2 * i]) | (uint(data_a_packed16[ib].qs[2 * i + 1]) << 16);
            const vec4 v0 = (vec4(unpack8((vui >> 0) & 0x0F0F0F0F)) - 8.0f) * d;
            const vec4 v1 = (vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) - 8.0f) * d;
            vecAh_block[4*i+ 0] = float16_t(v0.x);
            vecAh_block[4*i+ 1] = float16_t(v0.y);
            vecAh_block[4*i+ 2] = float16_t(v0.z);
            vecAh_block[4*i+ 3] = float16_t(v0.w);
            vecAh_block[4*i+16] = float16_t(v1.x);
            vecAh_block[4*i+17] = float16_t(v1.y);
            vecAh_block[4*i+18] = float16_t(v1.z);
            vecAh_block[4*i+19] = float16_t(v1.w);
        }
    }

    extractSubArrayQCOM(vecAh_block, step%32, vecAh); // extract first 16 or second 16 channels
    vectorToCoopmatQCOM(vecAh, matA);
    vectorToCoopmatQCOM(vecBh, matB);
    matC = coopMatMulAdd(matA, matB, matC);
}
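The nibble unpacking above follows the llama.cpp 4-bit layout used by this example: byte j of a block's qs holds weight j in its low nibble and weight j + 16 in its high nibble, stored with an offset of 8 and multiplied by the block scale d. The host reference below is a sketch of that layout for checking shader output; dequant_q4_block is a hypothetical name, and the scale is taken as float rather than float16_t.

```c
#include <assert.h>
#include <stdint.h>

#define QK4_0 32 /* weights per block, matching vecAh_block[32] */

/* Hypothetical host reference for one 32-weight block. */
static void dequant_q4_block(const uint8_t *qs, float d, float *out)
{
    for (int j = 0; j < QK4_0 / 2; j++) {
        out[j]             = ((int)(qs[j] & 0x0F) - 8) * d; /* low nibble  */
        out[j + QK4_0 / 2] = ((int)(qs[j] >> 4) - 8) * d;   /* high nibble */
    }
}
```

In the shader, extractSubArrayQCOM(vecAh_block, step%32, vecAh) then selects out[0..15] or out[16..31] of this block for the current 16-wide K step.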