VK_EXT_shader_ocp_microscaling_types.proposal

This extension adds support for the floating-point types defined by the Open Compute Project (OCP) Microscaling (MX) formats specification. It does not provide any direct support for microscaled tensors or operations on microscaled data.

Problem Statement

With machine learning algorithms commonly being run on Vulkan, it has become desirable to support newer data types that increase compute throughput and reduce memory bandwidth. Some of these types are used to define so-called microscaled or block-scaled formats, in which a single scale factor applies to a block of values (often 32), each encoded using a low-precision floating-point type. These schemes strike a good balance between data volume and the dynamic range of the represented values, and they have gained popularity with both machine learning model authors and hardware providers because of the performance improvements they enable.

Solution Space

Several levels of support could be considered:

  1. add basic support (type definition, conversions) for the new floating-point types used to define these new formats
  2. add support in dot product and/or matrix multiplication operations
  3. add support for loading/decoding of microscaled tensor data

This proposal focuses on option 1.

Proposal

Add support for the floating-point types defined in the OCP Microscaling Format specification (https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf).

These types can be used in shaders to consume microscaled tensors and convert them to other data formats for which operations such as matrix multiplication are already supported by extensions such as VK_KHR_cooperative_matrix.
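The arithmetic behind such a conversion can be sketched on the host in C. The sketch below dequantizes one MXFP4 block (32 FP4 E2M1 elements sharing one E8M0 scale) to 32-bit floats, assuming the element and scale definitions of the OCP MX v1.0 specification. The storage layout (two elements per byte, low nibble first) and the names mxfp4_block, fp4_e2m1_to_float, and mxfp4_dequantize are hypothetical illustrations, not something this extension defines, and this is not shader code.

```c
#include <assert.h>
#include <math.h>
#include <stdint.h>

#define MX_BLOCK_SIZE 32

/* Hypothetical storage layout (an assumption, not part of the extension):
 * one E8M0 scale plus 32 FP4 E2M1 elements packed two per byte,
 * low nibble first. */
typedef struct {
    uint8_t scale;                      /* E8M0: value = 2^(scale - 127) */
    uint8_t packed[MX_BLOCK_SIZE / 2];
} mxfp4_block;

/* FP4 E2M1: 1 sign bit, 2 exponent bits (bias 1), 1 mantissa bit; no Inf/NaN. */
static float fp4_e2m1_to_float(uint8_t bits)
{
    assert(bits < 16);
    int sign = (bits >> 3) & 1;
    int exp = (bits >> 1) & 3;
    int man = bits & 1;
    float mag = (exp == 0) ? 0.5f * man                           /* subnormal */
                           : ldexpf(1.0f + 0.5f * man, exp - 1); /* normal */
    return sign ? -mag : mag;
}

/* Dequantize one block: out[i] = scale * element[i]. */
static void mxfp4_dequantize(const mxfp4_block *blk, float out[MX_BLOCK_SIZE])
{
    /* 0xFF is the E8M0 NaN encoding; a NaN scale makes every element NaN. */
    float scale = (blk->scale == 0xFF) ? NAN : ldexpf(1.0f, (int)blk->scale - 127);
    for (int i = 0; i < MX_BLOCK_SIZE; ++i) {
        uint8_t byte = blk->packed[i / 2];
        uint8_t nib = (i % 2 == 0) ? (byte & 0x0F) : (byte >> 4);
        out[i] = scale * fp4_e2m1_to_float(nib);
    }
}
```

For example, a block with scale 128 (2^1) whose first byte is 0x7A decodes element 0 (0xA, i.e. -1.0) to -2.0 and element 1 (0x7, i.e. 6.0) to 12.0. Expanding to float keeps the sketch simple; a real shader might instead convert to a narrower type that an existing matrix path accepts.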

Support for these types also enables further extensions to add direct support for matrix operations on microscaled tensors.

SPIR-V Changes

This extension adds new FP Encoding values, which can be specified as the FP Encoding operand when declaring a floating-point type with OpTypeFloat:

Value  FP Encoding             Width  Enabling Capability
4223   Float6E2M3EXT           6      Float6EXT
4224   Float6E3M2EXT           6      Float6EXT
4225   Float4E2M1EXT           4      Float4EXT
4226   Float8UnsignedE8M0EXT   8      Float8UnsignedE8M0EXT
4227   MXInt8EXT               8      MXInt8EXT

Each encoding specifies that the floating point type is encoded as the corresponding type (FP6 E2M3, FP6 E3M2, FP4 E2M1, unsigned FP8 E8M0, and MXINT8, respectively), as specified in the "OCP Microscaling Formats v1.0" specification.
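To make these encodings concrete, the following C sketch decodes a single element of each one to a 32-bit float on the host. It assumes the element definitions of the OCP MX v1.0 specification: FP4 E2M1 and FP6 E2M3 with exponent bias 1, FP6 E3M2 with bias 3, none of them encoding Inf or NaN; E8M0 as an unsigned power of two with bias 127 and 0xFF reserved for NaN; and MXINT8 as an 8-bit two's complement integer with an implicit scale of 2^-6. The enum and function names are hypothetical, and the sketch says nothing about how implementations store or convert these types.

```c
#include <assert.h>
#include <math.h>
#include <stdint.h>

/* FP Encoding operand values from the table above (names are hypothetical). */
enum ocp_fp_encoding {
    FLOAT6_E2M3 = 4223,
    FLOAT6_E3M2 = 4224,
    FLOAT4_E2M1 = 4225,
    FLOAT8_UNSIGNED_E8M0 = 4226,
    MXINT8 = 4227
};

/* Generic decoder for the signed mini-float layouts (E2M3, E3M2, E2M1):
 * sign | exponent | mantissa, bias = 2^(exp_bits - 1) - 1, no Inf/NaN. */
static float decode_minifloat(uint32_t bits, int exp_bits, int man_bits)
{
    int bias = (1 << (exp_bits - 1)) - 1;
    uint32_t man = bits & ((1u << man_bits) - 1);
    uint32_t exp = (bits >> man_bits) & ((1u << exp_bits) - 1);
    int sign = (bits >> (exp_bits + man_bits)) & 1;
    float frac = (float)man / (float)(1 << man_bits);
    float mag = (exp == 0) ? ldexpf(frac, 1 - bias)                /* subnormal */
                           : ldexpf(1.0f + frac, (int)exp - bias); /* normal */
    return sign ? -mag : mag;
}

/* Decode one element of the given encoding to a 32-bit float. */
static float ocp_decode(enum ocp_fp_encoding enc, uint32_t bits)
{
    switch (enc) {
    case FLOAT6_E2M3: assert(bits < 64); return decode_minifloat(bits, 2, 3);
    case FLOAT6_E3M2: assert(bits < 64); return decode_minifloat(bits, 3, 2);
    case FLOAT4_E2M1: assert(bits < 16); return decode_minifloat(bits, 2, 1);
    case FLOAT8_UNSIGNED_E8M0:
        assert(bits < 256);
        /* No sign bit and no zero: 2^(bits - 127), with 0xFF reserved for NaN. */
        return bits == 0xFF ? NAN : ldexpf(1.0f, (int)bits - 127);
    case MXINT8: {
        assert(bits < 256);
        /* Two's complement int8, interpreted with an implicit scale of 2^-6. */
        int v = (int)bits - ((bits & 0x80u) ? 256 : 0);
        return (float)v / 64.0f;
    }
    }
    return NAN; /* unreachable for valid enumerants */
}
```

With these definitions the largest finite magnitudes are 6.0 (E2M1), 7.5 (E2M3), and 28.0 (E3M2), while MXINT8 covers -2.0 to 1.984375 in steps of 1/64. A shared decoder for the three signed layouts keeps the bias and subnormal rules in one place.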

New capabilities enable both the declaration of the type and its use with cooperative matrix features:

Value  Capability
4228   Float6EXT
       Uses OpTypeFloat to specify types with the Float6E2M3EXT or Float6E3M2EXT FP Encoding, and values of these types with a few instructions.
4229   Float4EXT
       Uses OpTypeFloat to specify types with the Float4E2M1EXT FP Encoding, and values of this type with a few instructions.
4230   Float8UnsignedE8M0EXT
       Uses OpTypeFloat to specify types with the Float8UnsignedE8M0EXT FP Encoding, and values of this type with a few instructions.
4231   MXInt8EXT
       Uses OpTypeFloat to specify types with the MXInt8EXT FP Encoding, and values of this type with a few instructions.
4232   BitcastExtractEXT
       Uses the OpBitcastExtractEXT instruction.
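The two tables together imply a simple rule for OpTypeFloat declarations that use these encodings: the width must match the encoding, and the module must declare the enabling capability. The C sketch below expresses only that rule as a hypothetical validator check; the names are illustrative, and a real SPIR-V validator would also check how values of these types may be used, which is not covered here.

```c
#include <assert.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>

/* Capability enumerants from the table above (names are hypothetical). */
enum ocp_capability {
    CAP_FLOAT6 = 4228,
    CAP_FLOAT4 = 4229,
    CAP_FLOAT8_UNSIGNED_E8M0 = 4230,
    CAP_MXINT8 = 4231,
    CAP_BITCAST_EXTRACT = 4232
};

/* One row of the FP Encoding table: encoding, required width, enabling capability. */
typedef struct {
    uint32_t encoding;
    uint32_t width;
    enum ocp_capability capability;
} encoding_rule;

static const encoding_rule rules[] = {
    {4223, 6, CAP_FLOAT6},               /* Float6E2M3EXT */
    {4224, 6, CAP_FLOAT6},               /* Float6E3M2EXT */
    {4225, 4, CAP_FLOAT4},               /* Float4E2M1EXT */
    {4226, 8, CAP_FLOAT8_UNSIGNED_E8M0}, /* Float8UnsignedE8M0EXT */
    {4227, 8, CAP_MXINT8},               /* MXInt8EXT */
};

/* Hypothetical check for "OpTypeFloat <width> <encoding>": the width must match
 * the table and the enabling capability must be among the declared ones. */
static bool float_type_is_valid(uint32_t width, uint32_t encoding,
                                const enum ocp_capability *declared, size_t n_declared)
{
    assert(declared != NULL || n_declared == 0);
    for (size_t i = 0; i < sizeof rules / sizeof rules[0]; ++i) {
        if (rules[i].encoding != encoding)
            continue;
        if (rules[i].width != width)
            return false; /* wrong width for this encoding */
        for (size_t j = 0; j < n_declared; ++j)
            if (declared[j] == rules[i].capability)
                return true;
        return false; /* enabling capability not declared */
    }
    return false; /* not one of the encodings added by this extension */
}
```

For instance, a module declaring Float4EXT may declare a 4-bit type with Float4E2M1EXT, but not an 8-bit one, and not an FP6 type.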

API Changes

Features

This extension adds four features that map to the SPIR-V capabilities described above:

typedef struct VkPhysicalDeviceShaderOCPMicroscalingTypesFeaturesEXT {
    VkStructureType    sType;
    void*              pNext;
    VkBool32           shaderFloat4;
    VkBool32           shaderFloat6;
    VkBool32           shaderFloat8UnsignedE8M0;
    VkBool32           shaderMXInt8;
} VkPhysicalDeviceShaderOCPMicroscalingTypesFeaturesEXT;

  • shaderFloat4 indicates support for the Float4EXT and BitcastExtractEXT capabilities.
  • shaderFloat6 indicates support for the Float6EXT and BitcastExtractEXT capabilities.
  • shaderFloat8UnsignedE8M0 indicates support for the Float8UnsignedE8M0EXT capability.
  • shaderMXInt8 indicates support for the MXInt8EXT capability.

An implementation that exposes this extension must support at least one of these features.
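The mapping from features to SPIR-V capabilities, and the at-least-one rule, can be sketched in C. In an application the values would come from chaining VkPhysicalDeviceShaderOCPMicroscalingTypesFeaturesEXT into vkGetPhysicalDeviceFeatures2; to stay self-contained without Vulkan headers, this sketch instead uses a hypothetical mirror struct (ocp_mx_features, without sType and pNext) and hypothetical capability bit flags.

```c
#include <assert.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>

/* Hypothetical mirror of the feature struct above, without sType/pNext.
 * Each field is a VkBool32 (a uint32_t) in the real struct. */
typedef struct {
    uint32_t shaderFloat4;
    uint32_t shaderFloat6;
    uint32_t shaderFloat8UnsignedE8M0;
    uint32_t shaderMXInt8;
} ocp_mx_features;

/* Hypothetical bit flags for the SPIR-V capabilities a shader may declare. */
enum {
    CAPBIT_FLOAT4 = 1u << 0,
    CAPBIT_FLOAT6 = 1u << 1,
    CAPBIT_FLOAT8_UNSIGNED_E8M0 = 1u << 2,
    CAPBIT_MXINT8 = 1u << 3,
    CAPBIT_BITCAST_EXTRACT = 1u << 4
};

/* Capabilities enabled by the reported features, following the list above:
 * both shaderFloat4 and shaderFloat6 also enable BitcastExtractEXT. */
static uint32_t allowed_capabilities(const ocp_mx_features *f)
{
    assert(f != NULL);
    uint32_t caps = 0;
    if (f->shaderFloat4) caps |= CAPBIT_FLOAT4 | CAPBIT_BITCAST_EXTRACT;
    if (f->shaderFloat6) caps |= CAPBIT_FLOAT6 | CAPBIT_BITCAST_EXTRACT;
    if (f->shaderFloat8UnsignedE8M0) caps |= CAPBIT_FLOAT8_UNSIGNED_E8M0;
    if (f->shaderMXInt8) caps |= CAPBIT_MXINT8;
    return caps;
}

/* An implementation exposing the extension must report at least one feature. */
static bool features_are_consistent(const ocp_mx_features *f)
{
    assert(f != NULL);
    return f->shaderFloat4 || f->shaderFloat6 ||
           f->shaderFloat8UnsignedE8M0 || f->shaderMXInt8;
}
```

An application could use allowed_capabilities to decide which shader variants it can load on a given device, falling back to a different data format when a required bit is missing.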

Issues

None.