VK_EXT_shader_ocp_microscaling_types.proposal
This extension adds support for the floating-point types defined by the Open Compute Project (OCP) Microscaling (MX) formats specification. It does not provide any direct support for microscaled tensors or operations on microscaled data.
Problem Statement
With machine learning algorithms commonly being run on Vulkan, it has become desirable to support the newer data types that allow increased compute throughput and reduced memory bandwidth. Some of these types are used in the definition of so-called microscaled or block-scaled formats that use a scale factor that applies to a block of values (often 32) encoded using a low precision floating-point type. These schemes provide a good balance between the volume of data and dynamic range of the values represented and have gained popularity with both machine learning model authors and hardware providers because of the performance improvements that they enable.
Solution Space
Several levels of support could be considered:
- add basic support (type definition, conversions) for the new floating-point types used to define these new formats
- add support in dot product and/or matrix multiplication operations
- add support for loading/decoding of microscaled tensor data
This proposal focuses on the first of these: basic support for the new floating-point types.
Proposal
Add support for the floating-point types defined in the OCP Microscaling Format specification (https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf).
These types can be used in shaders to consume microscaled tensors and convert them to other data formats for which operations such as matrix multiplication are already supported, for example by VK_KHR_cooperative_matrix.
Support for these types also enables further extensions to add direct support for matrix operations on microscaled tensors.
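As a rough illustration of the conversion described above, the following CPU-side C sketch expands one block of FP4 E2M1 elements with a shared E8M0 scale into 32-bit floats. It is reference code, not shader code, and it is not something this extension provides. The names MxFp4Block, MX_BLOCK_SIZE, and mx_fp4_block_to_float are hypothetical, and the layout (one 4-bit code per byte, scale stored alongside the elements) is an assumption made for readability rather than a layout defined by this proposal.

```c
#include <assert.h>
#include <math.h>
#include <stddef.h>
#include <stdint.h>

#define MX_BLOCK_SIZE 32 /* assumed block size; the text above says "often 32" */

/* Magnitudes of the eight non-negative FP4 E2M1 codes (no Inf or NaN). */
static const float k_e2m1_magnitude[8] = {
    0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f
};

/* Hypothetical layout: one shared E8M0 scale, one 4-bit code per byte. */
typedef struct MxFp4Block {
    uint8_t scale_e8m0;
    uint8_t codes[MX_BLOCK_SIZE];
} MxFp4Block;

/* Expand a block to floats: out[i] = scale * element[i]. */
static void mx_fp4_block_to_float(const MxFp4Block *block,
                                  float out[MX_BLOCK_SIZE])
{
    assert(block != NULL && out != NULL);

    /* E8M0 holds only a biased exponent: 2^(e - 127); 0xFF encodes NaN. */
    const float scale = (block->scale_e8m0 == 0xFF)
                            ? NAN
                            : ldexpf(1.0f, (int)block->scale_e8m0 - 127);

    for (size_t i = 0; i < MX_BLOCK_SIZE; ++i) {
        const uint8_t code = block->codes[i] & 0x0F;
        const float magnitude = k_e2m1_magnitude[code & 0x07];
        /* Bit 3 of the code is the sign bit. */
        out[i] = scale * ((code & 0x08) ? -magnitude : magnitude);
    }
}
```

A shader using the types from this extension would perform the equivalent conversion per element and then feed the result to an operation that already accepts the wider type.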
SPIR-V Changes
This extension adds new Floating Point Encoding values, which can be specified as the FP Encoding operand when declaring a floating-point type with OpTypeFloat:
| Value | FP Encoding | Description | Width(s) | Enabling Capabilities |
|---|---|---|---|---|
| 4223 | Float6E2M3EXT | The floating point type is encoded as an FP6 E2M3 type, as specified in the "OCP Microscaling Formats v1.0" specification. | 6 | Float6EXT |
| 4224 | Float6E3M2EXT | The floating point type is encoded as an FP6 E3M2 type, as specified in the "OCP Microscaling Formats v1.0" specification. | 6 | Float6EXT |
| 4225 | Float4E2M1EXT | The floating point type is encoded as an FP4 E2M1 type, as specified in the "OCP Microscaling Formats v1.0" specification. | 4 | Float4EXT |
| 4226 | Float8UnsignedE8M0EXT | The floating point type is encoded as an unsigned FP8 E8M0 type, as specified in the "OCP Microscaling Formats v1.0" specification. | 8 | Float8UnsignedE8M0EXT |
| 4227 | MXInt8EXT | The floating point type is encoded as an MXINT8 type, as specified in the "OCP Microscaling Formats v1.0" specification. | 8 | MXInt8EXT |
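To make the element encodings more concrete, here is a minimal C sketch that decodes the FP4 and FP6 bit patterns and MXINT8 into a 32-bit float. It assumes the layouts given in the OCP Microscaling Formats v1.0 specification as understood here: a sign bit, an exponent with bias 2^(E-1) - 1, subnormals, and no Inf or NaN for FP4/FP6; and two's complement with an implicit scale of 2^-6 for MXINT8. The function names are hypothetical helpers and are not part of SPIR-V or Vulkan.

```c
#include <assert.h>
#include <math.h>
#include <stdint.h>

/* Decode a sign/exponent/mantissa encoding without Inf or NaN
 * (FP4 E2M1, FP6 E2M3, FP6 E3M2) held in the low bits of `bits`. */
static float mx_decode_minifloat(uint8_t bits, int exp_bits, int man_bits)
{
    assert(exp_bits >= 1 && man_bits >= 1 && 1 + exp_bits + man_bits <= 8);

    const int bias = (1 << (exp_bits - 1)) - 1;
    const unsigned man = bits & ((1u << man_bits) - 1u);
    const unsigned exp = (bits >> man_bits) & ((1u << exp_bits) - 1u);
    const unsigned sign = (bits >> (exp_bits + man_bits)) & 1u;

    float magnitude;
    if (exp == 0) {
        /* Subnormal: no implicit leading one, exponent fixed at 1 - bias. */
        magnitude = ldexpf((float)man, 1 - bias - man_bits);
    } else {
        /* Normal: implicit leading one in front of the mantissa bits. */
        magnitude = ldexpf((float)((1u << man_bits) | man),
                           (int)exp - bias - man_bits);
    }
    return sign ? -magnitude : magnitude;
}

/* Decode MXINT8: 8-bit two's complement with an implicit scale of 2^-6. */
static float mx_decode_mxint8(uint8_t bits)
{
    const int value = (bits >= 128) ? (int)bits - 256 : (int)bits;
    return (float)value / 64.0f;
}
```

For example, mx_decode_minifloat(code, 2, 1) handles Float4E2M1EXT, (code, 2, 3) handles Float6E2M3EXT, and (code, 3, 2) handles Float6E3M2EXT. The E8M0 scale type is not covered by this helper because it has no sign or mantissa bits.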
New capabilities enable both the declaration of the type and its use with cooperative matrix features:
| Value | Capability | Description | Implicitly Declares |
|---|---|---|---|
| 4228 | Float6EXT | Uses OpTypeFloat to specify types with the Float6E2M3EXT or Float6E3M2EXT FP Encoding and values of this type with a few instructions. | |
| 4229 | Float4EXT | Uses OpTypeFloat to specify types with the Float4E2M1EXT FP Encoding and values of this type with a few instructions. | |
| 4230 | Float8UnsignedE8M0EXT | Uses OpTypeFloat to specify types with the Float8UnsignedE8M0EXT FP Encoding and values of this type with a few instructions. | |
| 4231 | MXInt8EXT | Uses OpTypeFloat to specify types with the MXInt8EXT FP Encoding and values of this type with a few instructions. | |
| 4232 | BitcastExtractEXT | Uses the OpBitcastExtractEXT instruction. | |
API Changes
Features
This extension adds four features, one for each of the type capabilities exposed by the SPIR-V extension:
typedef struct VkPhysicalDeviceShaderOCPMicroscalingTypesFeaturesEXT {
    VkStructureType    sType;
    void*              pNext;
    VkBool32           shaderFloat4;
    VkBool32           shaderFloat6;
    VkBool32           shaderFloat8UnsignedE8M0;
    VkBool32           shaderMXInt8;
} VkPhysicalDeviceShaderOCPMicroscalingTypesFeaturesEXT;
- shaderFloat4 indicates support for the Float4EXT and BitcastExtractEXT capabilities.
- shaderFloat6 indicates support for the Float6EXT and BitcastExtractEXT capabilities.
- shaderFloat8UnsignedE8M0 indicates support for the Float8UnsignedE8M0EXT capability.
- shaderMXInt8 indicates support for the MXInt8EXT capability.
An implementation that supports this extension must support at least one of these features.
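The feature-to-capability mapping above can be sketched in C as follows. To keep the snippet compilable without Vulkan headers, it uses a trimmed, hypothetical mirror of the feature struct (MicroscalingFeatures, with sType and pNext omitted) and a local stand-in for VkBool32. The CAP_* flags and both function names are illustrative only. In an application, the real struct would normally be filled by chaining it into the pNext list of VkPhysicalDeviceFeatures2 and calling vkGetPhysicalDeviceFeatures2; the sType value for it is not given in this section and is therefore not shown.

```c
#include <assert.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>

/* Stand-in for the Vulkan typedef so the sketch builds on its own. */
typedef uint32_t VkBool32;

/* Hypothetical trimmed mirror of the feature struct (no sType/pNext). */
typedef struct MicroscalingFeatures {
    VkBool32 shaderFloat4;
    VkBool32 shaderFloat6;
    VkBool32 shaderFloat8UnsignedE8M0;
    VkBool32 shaderMXInt8;
} MicroscalingFeatures;

/* Hypothetical bit flags naming the SPIR-V capabilities listed earlier. */
enum {
    CAP_FLOAT6               = 1u << 0,
    CAP_FLOAT4               = 1u << 1,
    CAP_FLOAT8_UNSIGNED_E8M0 = 1u << 2,
    CAP_MXINT8               = 1u << 3,
    CAP_BITCAST_EXTRACT      = 1u << 4
};

/* Return the set of capabilities a shader may declare for these features. */
static uint32_t allowed_capabilities(const MicroscalingFeatures *features)
{
    assert(features != NULL);

    uint32_t caps = 0;
    if (features->shaderFloat4)
        caps |= CAP_FLOAT4 | CAP_BITCAST_EXTRACT;
    if (features->shaderFloat6)
        caps |= CAP_FLOAT6 | CAP_BITCAST_EXTRACT;
    if (features->shaderFloat8UnsignedE8M0)
        caps |= CAP_FLOAT8_UNSIGNED_E8M0;
    if (features->shaderMXInt8)
        caps |= CAP_MXINT8;
    return caps;
}

/* Check the "at least one feature" requirement stated above. */
static bool meets_minimum_support(const MicroscalingFeatures *features)
{
    return allowed_capabilities(features) != 0;
}
```

Note that BitcastExtractEXT has no feature of its own in this sketch: it becomes available through either shaderFloat4 or shaderFloat6, which matches the list above.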
Issues
None.