{"id":113179,"date":"2026-03-05T09:00:00","date_gmt":"2026-03-05T17:00:00","guid":{"rendered":"https:\/\/developer.nvidia.com\/blog\/?p=113179"},"modified":"2026-03-05T11:48:13","modified_gmt":"2026-03-05T19:48:13","slug":"tuning-flash-attention-for-peak-performance-in-nvidia-cuda-tile","status":"publish","type":"post","link":"https:\/\/developer.nvidia.com\/blog\/tuning-flash-attention-for-peak-performance-in-nvidia-cuda-tile\/","title":{"rendered":"Tuning Flash Attention for Peak Performance in NVIDIA CUDA Tile"},"content":{"rendered":"\n<p>In this post, we dive into one of the most critical workloads in modern AI: <strong>Flash Attention<\/strong>, where you\u2019ll learn:<\/p>\n\n\n\n<ul class=\"wp-block-list\">\n<li><strong>How to implement Flash Attention using NVIDIA <\/strong><a href=\"https:\/\/github.com\/nvidia\/cutile-python\"><strong>cuTile<\/strong><\/a>. Walk through the complete code for a production-ready implementation.<\/li>\n\n\n\n<li><strong>The &#8220;trap and rescue&#8221; optimization journey<\/strong>. This case study shows how naive optimizations (like just increasing tile size) can backfire, and how to fix them.<\/li>\n\n\n\n<li><strong>Advanced techniques<\/strong> like FMA patterns, fast math, loop splitting, and adaptive tiling for maximum performance.<\/li>\n<\/ul>\n\n\n\n<p><strong>Environment requirements:<\/strong><\/p>\n\n\n\n<ul class=\"wp-block-list\">\n<li><strong>CUDA 13.1<\/strong> or higher<\/li>\n\n\n\n<li><strong>GPU architecture<\/strong>: Compute capability 8.X, 10.X, 11.X, 12.X (NVIDIA Ampere, NVIDIA Ada, NVIDIA Blackwell)<\/li>\n\n\n\n<li><strong>Python<\/strong>: 3.10 or higher<\/li>\n<\/ul>\n\n\n\n<p>See the <a href=\"https:\/\/docs.nvidia.com\/cuda\/cutile-python\/quickstart.html\">quickstart doc<\/a> for more information on installing cuTile Python.<\/p>\n\n\n\n<h2 id=\"what_is_attention\"  class=\"wp-block-heading\">What is attention?<a href=\"#what_is_attention\" class=\"heading-anchor-link\"><i class=\"fas fa-link\"><\/i><\/a><\/h2>\n\n\n\n<p>The attention mechanism is the computational heart of transformer models. Given a sequence of tokens, attention enables each token to &#8220;look at&#8221; every other token and decide how much to weigh their contributions. Mathematically, for input matrices Query (\\(Q\\)), Key (\\(K\\)), and Value (\\(V\\)), the output is:<\/p>\n\n\n\n<p>\\(O = \\text{softmax}\\left(\\frac{QK^T}{\\sqrt{d}}\\right)V\\)<\/p>\n\n\n\n<p>Where:<\/p>\n\n\n\n<ul class=\"wp-block-list\">\n<li>\\(Q \\text{ has shape } (N,d),\\ N \\text{ query tokens, each with dimension } d.\\)<\/li>\n\n\n\n<li>\\(K \\text{ has shape } (N,d),\\ N \\text{ key tokens.}\\)<\/li>\n\n\n\n<li>\\(V \\text{ has shape } (N,d),\\ N \\text{ value tokens.}\\)<\/li>\n\n\n\n<li>\\(\\text{The intermediate } QK^{T} \\text{ matrix has shape } (N,N), \\text{ is a problem.}\\)<\/li>\n<\/ul>\n\n\n\n<h3 id=\"the_memory_bandwidth_problem\"  class=\"wp-block-heading\">The memory bandwidth problem<a href=\"#the_memory_bandwidth_problem\" class=\"heading-anchor-link\"><i class=\"fas fa-link\"><\/i><\/a><\/h3>\n\n\n\n<p>For a sequence length of \\(N = 16,384\\) (common in modern LLMs), the attention matrix \\(QK^{T}\\) contains \\(N^2 = 268\\) million elements. 
In FP16, that&#8217;s 512 MB of intermediate storage per attention head, per batch item.<\/p>\n\n\n\n<p>Standard attention implementations:<\/p>\n\n\n\n<ol class=\"wp-block-list\">\n<li>Compute the full \\(N \\times N\\) attention matrix and write it to global memory (slow)<\/li>\n\n\n\n<li>Apply softmax row-by-row<\/li>\n\n\n\n<li>Read the matrix back and multiply by \\(V\\) <\/li>\n<\/ol>\n\n\n\n<p>This approach is <strong>memory-bound<\/strong> as the GPU spends most of its time waiting for data to move between HBM and compute units, rather than computing.<\/p>\n\n\n\n<h3 id=\"how_flash_attention_solves_the_memory_bandwidth_problem\"  class=\"wp-block-heading\">How Flash Attention solves the memory bandwidth problem<a href=\"#how_flash_attention_solves_the_memory_bandwidth_problem\" class=\"heading-anchor-link\"><i class=\"fas fa-link\"><\/i><\/a><\/h3>\n\n\n\n<p><strong>Flash Attention<\/strong> (introduced by Dao et al., 2022) is an IO-aware algorithm that never materializes the full \\(N \\times N\\) matrix. Instead, it:<\/p>\n\n\n\n<ol class=\"wp-block-list\">\n<li><strong>Tiles the computation<\/strong>: Processes \\(Q, K, V\\) in small blocks that fit in fast on-chip SMEM<\/li>\n\n\n\n<li><strong>Uses online softmax<\/strong>: Computes softmax incrementally without needing the full row<\/li>\n\n\n\n<li><strong>Fuses operations<\/strong>: Combines the matrix multiply and softmax into a single kernel pass<\/li>\n<\/ol>\n\n\n\n<p>The result is a <strong>2-4x speedup<\/strong> and significant memory savings, enabling longer context lengths.<\/p>\n\n\n<div class=\"wp-block-image\">\n<figure data-wp-context=\"{&quot;imageId&quot;:&quot;69efb766da574&quot;}\" data-wp-interactive=\"core\/image\" class=\"aligncenter size-full wp-lightbox-container\"><img loading=\"lazy\" decoding=\"async\" width=\"486\" height=\"348\" data-wp-class--hide=\"state.isContentHidden\" data-wp-class--show=\"state.isContentVisible\" data-wp-init=\"callbacks.setButtonStyles\" data-wp-on-async--click=\"actions.showLightbox\" data-wp-on-async--load=\"callbacks.setButtonStyles\" data-wp-on-async-window--resize=\"callbacks.setButtonStyles\" src=\"https:\/\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/03\/Tiled-computation.webp\" alt=\"A tiled flash attention figure showing Q, K^T, V and O in HBM, being accumulated to Q, K, V, and O in SMEM.\" class=\"wp-image-113183\" srcset=\"https:\/\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/03\/Tiled-computation.webp 486w, https:\/\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/03\/Tiled-computation-161x115.png 161w, https:\/\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/03\/Tiled-computation-300x215.png 300w, https:\/\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/03\/Tiled-computation-419x300.png 419w, https:\/\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/03\/Tiled-computation-126x90.png 126w, https:\/\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/03\/Tiled-computation-362x259.png 362w, https:\/\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/03\/Tiled-computation-154x110.png 154w\" sizes=\"auto, (max-width: 486px) 100vw, 486px\" \/><button\n\t\t\tclass=\"lightbox-trigger\"\n\t\t\ttype=\"button\"\n\t\t\taria-haspopup=\"dialog\"\n\t\t\taria-label=\"Enlarge\"\n\t\t\tdata-wp-init=\"callbacks.initTriggerButton\"\n\t\t\tdata-wp-on-async--click=\"actions.showLightbox\"\n\t\t\tdata-wp-style--right=\"state.imageButtonRight\"\n\t\t\tdata-wp-style--top=\"state.imageButtonTop\"\n\t\t>\n\t\t\t<svg 
xmlns=\"http:\/\/www.w3.org\/2000\/svg\" width=\"12\" height=\"12\" fill=\"none\" viewBox=\"0 0 12 12\">\n\t\t\t\t<path fill=\"#fff\" d=\"M2 0a2 2 0 0 0-2 2v2h1.5V2a.5.5 0 0 1 .5-.5h2V0H2Zm2 10.5H2a.5.5 0 0 1-.5-.5V8H0v2a2 2 0 0 0 2 2h2v-1.5ZM8 12v-1.5h2a.5.5 0 0 0 .5-.5V8H12v2a2 2 0 0 1-2 2H8Zm2-12a2 2 0 0 1 2 2v2h-1.5V2a.5.5 0 0 0-.5-.5H8V0h2Z\" \/>\n\t\t\t<\/svg>\n\t\t<\/button><figcaption class=\"wp-element-caption\"><em>Figure 1. Tiled Flash Attention computation<\/em><\/figcaption><\/figure><\/div>\n\n\n<h2 id=\"understanding_online_softmax\"  class=\"wp-block-heading\">Understanding online softmax<a href=\"#understanding_online_softmax\" class=\"heading-anchor-link\"><i class=\"fas fa-link\"><\/i><\/a><\/h2>\n\n\n\n<p>The key algorithmic insight of Flash Attention is the online softmax trick. The numerically stable <strong>safe softmax<\/strong> requires knowing the maximum value across the entire row before computing:<\/p>\n\n\n\n<p>\\(\\text{softmax}(x_i) = \\frac{e^{x_i &#8211; \\max(x)}}{\\sum_j e^{x_j &#8211; \\max(x)}}\\)<\/p>\n\n\n\n<p>But if we&#8217;re processing tiles, we don&#8217;t have access to the full row. Online softmax solves this by maintaining running statistics that can be updated incrementally.<\/p>\n\n\n\n<h3 id=\"the_online_softmax_algorithm\"  class=\"wp-block-heading\">The online softmax algorithm<a href=\"#the_online_softmax_algorithm\" class=\"heading-anchor-link\"><i class=\"fas fa-link\"><\/i><\/a><\/h3>\n\n\n\n<p>We maintain two running values for each row:<\/p>\n\n\n\n<ul class=\"wp-block-list\">\n<li>\\(m_i\\): The maximum value seen so far (for numerical stability)<\/li>\n\n\n\n<li>\\(l_i\\): The sum of exponentials seen so far (the softmax denominator)<\/li>\n<\/ul>\n\n\n\n<p>When we process a new tile with values \\(x_{new}\\):<\/p>\n\n\n\n<ol class=\"wp-block-list\">\n<li><strong>Update the maximum<\/strong>: \\(m_{new} = \\max(m_i, \\max(x_{new}))\\)<\/li>\n\n\n\n<li><strong>Compute correction factor<\/strong>: \\(\\alpha = e^{m_i &#8211; m_{new}}\\) (rescales previous work)<\/li>\n\n\n\n<li><strong>Update the sum<\/strong>: \\(l_i = l_i \\cdot \\alpha + \\sum e^{x_{new} &#8211; m_{new}}\\)<\/li>\n\n\n\n<li><strong>Update the accumulator<\/strong>: \\(acc = acc \\cdot \\alpha + P_{new} \\cdot V_{tile}\\)<\/li>\n<\/ol>\n\n\n\n<p>\\(P_{new}\\) is the matrix of the attention weights, and \\(V_{tile}\\) is the value matrix tile, corresponding to the Key tile of the current iteration. At the end, we normalize: \\(O = acc \/ l_i\\)<\/p>\n\n\n\n<p>This enables us to compute an exact softmax without ever storing the full row.<\/p>\n\n\n\n<h2 id=\"causal_attention_and_grouped-query_attention\"  class=\"wp-block-heading\">Causal attention and grouped-query attention<a href=\"#causal_attention_and_grouped-query_attention\" class=\"heading-anchor-link\"><i class=\"fas fa-link\"><\/i><\/a><\/h2>\n\n\n\n<p>Before diving into the implementation, let&#8217;s understand two important attention variants used in modern LLMs:<\/p>\n\n\n\n<h3 id=\"causal_attention&nbsp;\"  class=\"wp-block-heading\">Causal attention&nbsp;<a href=\"#causal_attention&nbsp;\" class=\"heading-anchor-link\"><i class=\"fas fa-link\"><\/i><\/a><\/h3>\n\n\n\n<p>In autoregressive language models like GPT, LLaMA, and Claude, each token can only attend to <strong>previous tokens<\/strong> in the sequence, not future ones. 
This prevents &#8220;cheating&#8221; during training, where the model looks ahead to predict the next word.<\/p>\n\n\n\n<p>Mathematically, we apply a <strong>triangular mask<\/strong> to the attention scores:<\/p>\n\n\n\n<p>\\(\\text{mask}_{ij} = \\begin{cases} 0 &amp; \\text{if } i \\geq j \\text{ (query position \u2265 key position)} \\\\ -\\infty &amp; \\text{if } i &lt; j \\text{ (future tokens)} \\end{cases}\\)<\/p>\n\n\n\n<p>The masked attention becomes:<\/p>\n\n\n\n<p>\\(O = \\text{softmax}\\left(\\frac{QK^T}{\\sqrt{d}} + \\text{mask}\\right)V\\)<\/p>\n\n\n\n<p>Adding \\(-\\infty\\) to future positions ensures they become zero after softmax, effectively blocking information flow from future tokens.<\/p>\n\n\n<div class=\"wp-block-image\">\n<figure data-wp-context=\"{&quot;imageId&quot;:&quot;69efb766db854&quot;}\" data-wp-interactive=\"core\/image\" class=\"aligncenter size-full wp-lightbox-container\"><img loading=\"lazy\" decoding=\"async\" width=\"275\" height=\"275\" data-wp-class--hide=\"state.isContentHidden\" data-wp-class--show=\"state.isContentVisible\" data-wp-init=\"callbacks.setButtonStyles\" data-wp-on-async--click=\"actions.showLightbox\" data-wp-on-async--load=\"callbacks.setButtonStyles\" data-wp-on-async-window--resize=\"callbacks.setButtonStyles\" src=\"https:\/\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/03\/Casual-Attention.webp\" alt=\"Causal attention mask matrix for 4 tokens showing how the upper triangle of the matrix is masked to 0, meaning that those values are not used in the computation.\u00a0\" class=\"wp-image-113186\" srcset=\"https:\/\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/03\/Casual-Attention.webp 275w, https:\/\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/03\/Casual-Attention-115x115.png 115w, https:\/\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/03\/Casual-Attention-90x90.png 90w, https:\/\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/03\/Casual-Attention-32x32.png 32w, https:\/\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/03\/Casual-Attention-50x50.png 50w, https:\/\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/03\/Casual-Attention-64x64.png 64w, https:\/\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/03\/Casual-Attention-96x96.png 96w, https:\/\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/03\/Casual-Attention-128x128.png 128w, https:\/\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/03\/Casual-Attention-150x150.png 150w, https:\/\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/03\/Casual-Attention-110x110.png 110w\" sizes=\"auto, (max-width: 275px) 100vw, 275px\" \/><button\n\t\t\tclass=\"lightbox-trigger\"\n\t\t\ttype=\"button\"\n\t\t\taria-haspopup=\"dialog\"\n\t\t\taria-label=\"Enlarge\"\n\t\t\tdata-wp-init=\"callbacks.initTriggerButton\"\n\t\t\tdata-wp-on-async--click=\"actions.showLightbox\"\n\t\t\tdata-wp-style--right=\"state.imageButtonRight\"\n\t\t\tdata-wp-style--top=\"state.imageButtonTop\"\n\t\t>\n\t\t\t<svg xmlns=\"http:\/\/www.w3.org\/2000\/svg\" width=\"12\" height=\"12\" fill=\"none\" viewBox=\"0 0 12 12\">\n\t\t\t\t<path fill=\"#fff\" d=\"M2 0a2 2 0 0 0-2 2v2h1.5V2a.5.5 0 0 1 .5-.5h2V0H2Zm2 10.5H2a.5.5 0 0 1-.5-.5V8H0v2a2 2 0 0 0 2 2h2v-1.5ZM8 12v-1.5h2a.5.5 0 0 0 .5-.5V8H12v2a2 2 0 0 1-2 2H8Zm2-12a2 2 0 0 1 2 2v2h-1.5V2a.5.5 0 0 0-.5-.5H8V0h2Z\" \/>\n\t\t\t<\/svg>\n\t\t<\/button><figcaption class=\"wp-element-caption\"><em>Figure 2. 
Causal attention mask for four tokens<\/em><\/figcaption><\/figure><\/div>\n\n\n<p>With causal masking, roughly <strong>half the attention matrix is masked<\/strong> (the upper triangle). We can skip computing these masked tiles entirely, providing a 2x algorithmic speedup. This is crucial for the K-loop splitting optimization.<\/p>\n\n\n\n<h3 id=\"grouped-query_attention\"  class=\"wp-block-heading\">Grouped-query attention<a href=\"#grouped-query_attention\" class=\"heading-anchor-link\"><i class=\"fas fa-link\"><\/i><\/a><\/h3>\n\n\n\n<p>Standard multi-head attention has separate \\(K,V\\) matrices for each attention head, leading to high memory usage:<\/p>\n\n\n\n<ul class=\"wp-block-list\">\n<li><strong>Multi-head attention (MHA)<\/strong>: 32 query heads \u2192 32 K\/V heads (1:1 ratio)<\/li>\n\n\n\n<li><strong>Grouped-query attention (GQA)<\/strong>: 32 query heads \u2192 4 K\/V heads (8:1 ratio)<\/li>\n\n\n\n<li><strong>Multi-query attention (MQA)<\/strong>: 32 query heads \u2192 1 K\/V head (32:1 ratio)<\/li>\n<\/ul>\n\n\n\n<p>In GQA, multiple query heads <strong>share<\/strong> the same K\/V heads. For example, with 32 query heads and 4 K\/V heads:<\/p>\n\n\n\n<ul class=\"wp-block-list\">\n<li>Query heads 0-7 use K\/V head 0<\/li>\n\n\n\n<li>Query heads 8-15 use K\/V head 1<\/li>\n\n\n\n<li>Query heads 16-23 use K\/V head 2<\/li>\n\n\n\n<li>Query heads 24-31 use K\/V head 3<\/li>\n<\/ul>\n\n\n\n<p>This reduces K\/V cache size by <strong>8x<\/strong> during inference, critical for serving long-context models. Modern LLMs like Llama 2, Llama 3, Mistral, and Qwen use GQA extensively.<\/p>\n\n\n\n<p><strong>When implementing GQA in Flash Attention,<\/strong> each CUDA block computes attention for one query head, but loads the appropriate shared K\/V head:<\/p>\n\n\n<div class=\"wp-block-syntaxhighlighter-code \"><pre class=\"brush: python; title: ; notranslate\" title=\"\">\nhead_idx = bid_y % num_heads              # Which query head (0-31)\nkv_head_idx = head_idx \/\/ query_group_size # Which K\/V head (0-3)\n<\/pre><\/div>\n\n\n<p>With a query group size of 8, query heads 0-7 all map to <code>kv_head_idx = 0<\/code>, sharing the same K\/V tiles in memory.<\/p>\n\n\n\n<h2 id=\"part_1_the_flash_attention_kernel_in_cuda_tile\"  class=\"wp-block-heading\">Part 1: The flash attention kernel in CUDA Tile<a href=\"#part_1_the_flash_attention_kernel_in_cuda_tile\" class=\"heading-anchor-link\"><i class=\"fas fa-link\"><\/i><\/a><\/h2>\n\n\n\n<p>Let&#8217;s implement Flash Attention step-by-step. Our baseline uses small <strong>64\u00d764 tiles<\/strong> and straightforward code\u2014correct but not yet optimized.<\/p>\n\n\n\n<h3 id=\"1_defining_the_kernel_interface\"  class=\"wp-block-heading\">1. Defining the kernel interface<a href=\"#1_defining_the_kernel_interface\" class=\"heading-anchor-link\"><i class=\"fas fa-link\"><\/i><\/a><\/h3>\n\n\n\n<p>In cuTile, the <code>@ct.kernel<\/code> decorator marks a Python function as a GPU kernel. 
We pass compile-time constants using <code>ct.Constant[T]<\/code> type annotations:<\/p>\n\n\n<div class=\"wp-block-syntaxhighlighter-code \"><pre class=\"brush: python; title: ; notranslate\" title=\"\">\nimport math\nimport cuda.tile as ct\n\n# Type aliases for compile-time constants\nConstInt = ct.Constant&#x5B;int]\nConstBool = ct.Constant&#x5B;bool]\n\n# Conversion factor: we use exp2 instead of exp for efficiency\nINV_LOG_2 = 1.0 \/ math.log(2)\n\n@ct.kernel()\ndef fmha_kernel(\n    Q, K, V, Out,              # Input\/output tensors\n    qk_scale: float,           # Scale factor (1\/sqrt(d))\n    input_pos: int,            # Position offset for causal masking\n    TILE_D: ConstInt,          # Head dimension (for example, 128)\n    H: ConstInt,               # Number of attention heads\n    TILE_M: ConstInt,          # Tile size for Q dimension (for example, 64)\n    TILE_N: ConstInt,          # Tile size for K\/V dimension (for example, 64)\n    QUERY_GROUP_SIZE: ConstInt,# For Grouped Query Attention (GQA)\n    CAUSAL: ConstBool,         # Whether to apply causal mask\n    EVEN_K: ConstBool,         # Whether K length is divisible by TILE_N\n):\n<\/pre><\/div>\n\n\n<h3 id=\"2_block_id_mapping\"  class=\"wp-block-heading\">2. Block ID mapping<a href=\"#2_block_id_mapping\" class=\"heading-anchor-link\"><i class=\"fas fa-link\"><\/i><\/a><\/h3>\n\n\n\n<p>Each CUDA block computes one tile of the output. Using <a href=\"https:\/\/docs.nvidia.com\/cuda\/cutile-python\/generated\/cuda.tile.bid.html#cuda.tile.bid\"><code>ct.bid<\/code><\/a> , we map the 2D grid to batch\/head indices:<\/p>\n\n\n<div class=\"wp-block-syntaxhighlighter-code \"><pre class=\"brush: python; title: ; notranslate\" title=\"\">\n# Get block indices\n    bid_x = ct.bid(0)  # Which tile along the sequence dimension\n    bid_y = ct.bid(1)  # Which batch-head combination\n    \n    # Decode batch and head from flattened index\n    batch_idx = bid_y \/\/ H\n    head_idx = bid_y % H\n    \n    # For Grouped Query Attention: multiple Q heads share one K\/V head\n    off_kv_h = head_idx \/\/ QUERY_GROUP_SIZE\n<\/pre><\/div>\n\n\n<h3 id=\"3_initializing_accumulators\"  class=\"wp-block-heading\">3. Initializing accumulators<a href=\"#3_initializing_accumulators\" class=\"heading-anchor-link\"><i class=\"fas fa-link\"><\/i><\/a><\/h3>\n\n\n\n<p>Before the main loop, we initialize the online softmax state and output accumulator:<\/p>\n\n\n<div class=\"wp-block-syntaxhighlighter-code \"><pre class=\"brush: python; title: ; notranslate\" title=\"\">\n# Convert scale for base-2 exponential (faster than natural exp)\n    qk_scale = qk_scale * INV_LOG_2\n    \n    # Create position indices for this tile\n    offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32)\n    offs_m += input_pos\n    offs_m = offs_m&#x5B;:, None]  # Shape: &#x5B;TILE_M, 1]\n    \n    offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)\n    offs_n_tile = offs_n_tile&#x5B;None, :]  # Shape: &#x5B;1, TILE_N]\n    \n    # Online softmax state (float32 for numerical stability)\n    m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)  # Running max\n    l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)        # Running sum\n    acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)   # Output accumulator\n<\/pre><\/div>\n\n\n<p>We use <code>float32<\/code> for accumulators, even when inputs are float16 to maintain numerical precision during the iterative softmax computation.<\/p>\n\n\n\n<h3 id=\"4_loading_the_query_tile\"  class=\"wp-block-heading\">4. 
Loading the query tile<a href=\"#4_loading_the_query_tile\" class=\"heading-anchor-link\"><i class=\"fas fa-link\"><\/i><\/a><\/h3>\n\n\n\n<p>The query tile is loaded once and reused across all K\/V iterations:<\/p>\n\n\n<div class=\"wp-block-syntaxhighlighter-code \"><pre class=\"brush: python; title: ; notranslate\" title=\"\">\n    # Load Q tile: shape &#x5B;1, 1, TILE_M, TILE_D] -&gt; &#x5B;TILE_M, TILE_D]\n    q = ct.load(\n        Q, \n        index=(batch_idx, head_idx, bid_x, 0), \n        shape=(1, 1, TILE_M, TILE_D)\n    ).reshape((TILE_M, TILE_D))\n<\/pre><\/div>\n\n\n<p>The <a href=\"https:\/\/docs.nvidia.com\/cuda\/cutile-python\/generated\/cuda.tile.load.html#cuda.tile.load\"><code>ct.load<\/code><\/a> function handles boundary conditions automatically when the tile extends past the tensor edge.<\/p>\n\n\n\n<h3 id=\"5_the_main_loop_over_kv_tiles\"  class=\"wp-block-heading\">5. The main loop over K\/V tiles<a href=\"#5_the_main_loop_over_kv_tiles\" class=\"heading-anchor-link\"><i class=\"fas fa-link\"><\/i><\/a><\/h3>\n\n\n\n<p>This is the heart of Flash Attention. We iterate over K\/V tiles:<\/p>\n\n\n<div class=\"wp-block-syntaxhighlighter-code \"><pre class=\"brush: python; title: ; notranslate\" title=\"\">\n   # Calculate loop bounds\n    m_end = input_pos + (bid_x + 1) * TILE_M\n    k_seqlen = K.shape&#x5B;2]\n    \n    if CAUSAL:\n        # For causal attention, stop early (future tokens are masked)\n        Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)\n    else:\n        Tc = ct.cdiv(k_seqlen, TILE_N)\n    \n    for j in range(0, Tc):\n        # --- Step A: Load Key tile and compute QK^T ---\n        k = ct.load(\n            K,\n            index=(batch_idx, off_kv_h, 0, j),\n            shape=(1, 1, TILE_D, TILE_N),\n            order=(0, 1, 3, 2),  # Transpose for correct layout\n            latency=2            # Hint for memory prefetching\n        ).reshape((TILE_D, TILE_N))\n        \n        # Matrix multiply: Q @ K^T\n        qk = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)\n        qk = ct.mma(q, k, qk)  # Uses Tensor Cores automatically\n<\/pre><\/div>\n\n\n<p>The <code>order=(0,1,3,2)<\/code> parameter tells the <a href=\"https:\/\/docs.nvidia.com\/cuda\/cutile-python\/generated\/cuda.tile.load.html#cuda.tile.load\">cuTile load<\/a> operation to load K transposed, and <code>latency=2<\/code> hints that we can tolerate some latency (enabling better pipelining). Then we call <code>ct.mma(q, k, qk)<\/code> to perform the <a href=\"https:\/\/docs.nvidia.com\/cuda\/cutile-python\/generated\/cuda.tile.mma.html\">cuTile matrix multiply-accumulate<\/a>.<\/p>\n\n\n\n<h3 id=\"6_applying_the_causal_mask\"  class=\"wp-block-heading\">6. 
Applying the causal mask<a href=\"#6_applying_the_causal_mask\" class=\"heading-anchor-link\"><i class=\"fas fa-link\"><\/i><\/a><\/h3>\n\n\n\n<p>For autoregressive models (GPT, Llama, etc.), each token can only attend to previous tokens:<\/p>\n\n\n<div class=\"wp-block-syntaxhighlighter-code \"><pre class=\"brush: python; title: ; notranslate\" title=\"\">\n# --- Step B: Apply causal masking ---\n        if CAUSAL or not EVEN_K:\n            offs_n = j * TILE_N + offs_n_tile\n            mask = ct.full((TILE_M, TILE_N), True, dtype=ct.bool_)\n            \n            # Boundary mask (for non-divisible sequence lengths)\n            if not EVEN_K:\n                mask = mask &amp; (offs_n &lt; k_seqlen)\n            \n            # Causal mask: query position &gt;= key position\n            if CAUSAL:\n                mask = mask &amp; (offs_m &gt;= offs_n)\n            \n            # Convert to additive mask: True-&gt;0, False-&gt;-inf\n            mask = ct.where(mask, 0.0, -math.inf)\n            qk += mask\n<\/pre><\/div>\n\n\n<p>Adding <code>-inf<\/code> to masked positions ensures they become zero after softmax.<\/p>\n\n\n\n<h3 id=\"7_online_softmax_update\"  class=\"wp-block-heading\">7. Online softmax update<a href=\"#7_online_softmax_update\" class=\"heading-anchor-link\"><i class=\"fas fa-link\"><\/i><\/a><\/h3>\n\n\n\n<p>Now we update our running softmax statistics:<\/p>\n\n\n<div class=\"wp-block-syntaxhighlighter-code \"><pre class=\"brush: python; title: ; notranslate\" title=\"\">\n   # --- Step C: Online softmax ---\n        # Find max in current tile\n        qk_max = ct.max(qk, axis=-1, keepdims=True)\n        qk_max_scaled = qk_max * qk_scale\n        \n        # Update running maximum\n        m_ij = max(m_i, qk_max_scaled)\n        \n        # Scale QK scores\n        qk = qk * qk_scale\n        qk = qk - m_ij\n        \n        # Compute attention weights (using exp2 for speed)\n        p = ct.exp2(qk)\n        \n        # Update running sum\n        l_ij = ct.sum(p, axis=-1, keepdims=True)\n        alpha = ct.exp2(m_i - m_ij)  # Correction factor\n        l_i = l_i * alpha\n        l_i = l_i + l_ij\n        \n        # Rescale previous accumulator\n        acc = acc * alpha\n<\/pre><\/div>\n\n\n<h3 id=\"8_accumulating_the_output\"  class=\"wp-block-heading\">8. Accumulating the output<a href=\"#8_accumulating_the_output\" class=\"heading-anchor-link\"><i class=\"fas fa-link\"><\/i><\/a><\/h3>\n\n\n\n<p>Finally, we load the Value tile and accumulate:<\/p>\n\n\n<div class=\"wp-block-syntaxhighlighter-code \"><pre class=\"brush: python; title: ; notranslate\" title=\"\">\n# --- Step D: Load V and accumulate ---\n        v = ct.load(\n            V,\n            index=(batch_idx, off_kv_h, j, 0),\n            shape=(1, 1, TILE_N, TILE_D),\n            latency=4\n        ).reshape((TILE_N, TILE_D))\n        \n        # Cast attention weights back to input dtype for Tensor Core MMA\n        p = p.astype(Q.dtype)\n        \n        # Accumulate: acc += P @ V\n        acc = ct.mma(p, v, acc)\n        \n        # Update max for next iteration\n        m_i = m_ij\n<\/pre><\/div>\n\n\n<h3 id=\"9_final_normalization_and_store\"  class=\"wp-block-heading\">9. 
Final normalization and store<a href=\"#9_final_normalization_and_store\" class=\"heading-anchor-link\"><i class=\"fas fa-link\"><\/i><\/a><\/h3>\n\n\n\n<p>After processing all tiles, we normalize by the total sum and write the result:<\/p>\n\n\n<div class=\"wp-block-syntaxhighlighter-code \"><pre class=\"brush: python; title: ; notranslate\" title=\"\">\n   # --- Final: Normalize and store ---\n    acc = ct.truediv(acc, l_i)\n    acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)\n    ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)\n<\/pre><\/div>\n\n\n<h2 id=\"launching_the_kernel_host-side_code\"  class=\"wp-block-heading\">Launching the kernel: Host-side code<a href=\"#launching_the_kernel_host-side_code\" class=\"heading-anchor-link\"><i class=\"fas fa-link\"><\/i><\/a><\/h2>\n\n\n\n<p>Now let&#8217;s look at the host-side code that launches the kernel:<\/p>\n\n\n<div class=\"wp-block-syntaxhighlighter-code \"><pre class=\"brush: python; title: ; notranslate\" title=\"\">\nimport math\nfrom math import ceil\n\nimport torch\n\ndef tile_fmha(q, k, v, sm_scale=None, is_causal=True):\n    &quot;&quot;&quot;\n    Launch the Flash Attention kernel.\n    \n    Args:\n        q: Query tensor, shape &#x5B;batch, heads, seq_len, head_dim]\n        k: Key tensor, shape &#x5B;batch, kv_heads, seq_len, head_dim]\n        v: Value tensor, shape &#x5B;batch, kv_heads, seq_len, head_dim]\n        sm_scale: Softmax scale (default: 1\/sqrt(head_dim))\n        is_causal: Whether to apply causal masking\n    \n    Returns:\n        Output tensor, same shape as q\n    &quot;&quot;&quot;\n    if sm_scale is None:\n        sm_scale = 1.0 \/ math.sqrt(q.size(-1))\n    \n    batch_size, num_heads, seq_len, head_dim = q.shape\n    _, num_kv_heads, _, _ = k.shape\n    \n    # Calculate query group size for GQA\n    query_group_size = num_heads \/\/ num_kv_heads\n    \n    # Ensure contiguous memory layout\n    q = q.contiguous()\n    k = k.contiguous()\n    v = v.contiguous()\n    \n    # Allocate output\n    o = torch.empty_like(q)\n    \n    # Choose tile sizes (we&#039;ll optimize this later!)\n    TILE_M, TILE_N = 64, 64\n    \n    # Calculate grid dimensions\n    grid_x = ceil(seq_len \/ TILE_M)  # Number of tiles along sequence\n    grid_y = batch_size * num_heads  # One block per batch-head pair\n    grid = (grid_x, grid_y, 1)\n    \n    # Check if K length is evenly divisible\n    EVEN_K = (k.shape&#x5B;2] % TILE_N) == 0\n    \n    # Launch kernel\n    ct.launch(\n        torch.cuda.current_stream(),\n        grid,\n        fmha_kernel,\n        (q, k, v, o, sm_scale, 0, head_dim, num_heads,\n         TILE_M, TILE_N, query_group_size, is_causal, EVEN_K)\n    )\n    \n    return o\n<\/pre><\/div>\n\n\n<p>This baseline with <strong>64\u00d764 tiles<\/strong> works correctly. But can we make it faster?
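<\/p>\n\n\n\n<p>Before tuning, it helps to have a quick correctness harness. A minimal sketch (shapes are illustrative; <code>tile_fmha<\/code> is the launcher defined above) compares it against PyTorch&#8217;s built-in scaled dot-product attention:<\/p>\n\n\n<div class=\"wp-block-syntaxhighlighter-code \"><pre class=\"brush: python; title: ; notranslate\" title=\"\">\nimport torch\nimport torch.nn.functional as F\n\nbatch, heads, seq_len, head_dim = 4, 32, 1024, 128\nq = torch.randn(batch, heads, seq_len, head_dim, device=&quot;cuda&quot;, dtype=torch.float16)\nk = torch.randn_like(q)\nv = torch.randn_like(q)\n\nout = tile_fmha(q, k, v, is_causal=True)                       # our cuTile kernel\nref = F.scaled_dot_product_attention(q, k, v, is_causal=True)  # PyTorch reference\n\ntorch.testing.assert_close(out, ref, atol=2e-2, rtol=2e-2)     # FP16 tolerances\n<\/pre><\/div>\n\n\n<p>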
Let&#8217;s find out.<\/p>\n\n\n\n<h2 id=\"part_2_the_&#8220;trap_and_rescue&#8221;_optimization_journey\"  class=\"wp-block-heading\">Part 2: The &#8220;trap and rescue&#8221; optimization journey<a href=\"#part_2_the_&#8220;trap_and_rescue&#8221;_optimization_journey\" class=\"heading-anchor-link\"><i class=\"fas fa-link\"><\/i><\/a><\/h2>\n\n\n\n<p>We benchmark on the following configuration:<\/p>\n\n\n\n<ul class=\"wp-block-list\">\n<li><strong>Hardware<\/strong>: NVIDIA B200<\/li>\n\n\n\n<li><strong>Batch<\/strong>: 4, <strong>Heads<\/strong>: 32, <strong>Head dimension<\/strong>: 128<\/li>\n\n\n\n<li><strong>Attention<\/strong>: Causal, <strong>Dtype<\/strong>: FP16<\/li>\n\n\n\n<li><strong>Sequence lengths<\/strong>: 1024, 2048, 4096, 8192, 16384<\/li>\n<\/ul>\n\n\n\n<p>To interpret each step, we use Nsight Compute with a minimal section set:<\/p>\n\n\n\n<ul class=\"wp-block-list\">\n<li><code>LaunchStats<\/code><\/li>\n\n\n\n<li><code>Occupancy<\/code><\/li>\n\n\n\n<li><code>SpeedOfLight<\/code><\/li>\n\n\n\n<li><code>ComputeWorkloadAnalysis<\/code><\/li>\n\n\n\n<li><code>MemoryWorkloadAnalysis<\/code><\/li>\n<\/ul>\n\n\n\n<h3 id=\"baseline_performance\"  class=\"wp-block-heading\">Baseline performance<a href=\"#baseline_performance\" class=\"heading-anchor-link\"><i class=\"fas fa-link\"><\/i><\/a><\/h3>\n\n\n\n<figure class=\"wp-block-table aligncenter\"><table><thead><tr><th><strong>SeqLen<\/strong><\/th><th><strong>Throughput (TFLOPS)<\/strong><\/th><\/tr><\/thead><tbody><tr><td>1,024<\/td><td>330<\/td><\/tr><tr><td>2,048<\/td><td>441<\/td><\/tr><tr><td>4,096<\/td><td>511<\/td><\/tr><tr><td>8,192<\/td><td>546<\/td><\/tr><tr><td>16,384<\/td><td>566<\/td><\/tr><\/tbody><\/table><figcaption class=\"wp-element-caption\"><em>Table 1. Baseline performance without any specific optimizations<\/em><\/figcaption><\/figure>\n\n\n\n<p>This is our starting point with 64\u00d764 tiles and no optimizations.<\/p>\n\n\n\n<p><strong>NCU insight (SeqLen=1024, B200)<\/strong>:<\/p>\n\n\n\n<ul class=\"wp-block-list\">\n<li>Registers\/thread: 128<\/li>\n\n\n\n<li>Theoretical\/achieved occupancy: 25% \/ 19.8%<\/li>\n\n\n\n<li>Compute (SM) throughput: 37.8%<\/li>\n\n\n\n<li>Memory throughput: 19.7%<\/li>\n\n\n\n<li>Grid size: 2,048<\/li>\n<\/ul>\n\n\n\n<h2 id=\"1_the_trap_of_larger_tiles\"  class=\"wp-block-heading\">1. The trap of larger tiles<a href=\"#1_the_trap_of_larger_tiles\" class=\"heading-anchor-link\"><i class=\"fas fa-link\"><\/i><\/a><\/h2>\n\n\n\n<p>A common intuition in GPU programming is <strong>&#8220;bigger tiles = better performance.&#8221;<\/strong> Larger tiles:<\/p>\n\n\n\n<ul class=\"wp-block-list\">\n<li>Amortize memory access overhead.<\/li>\n\n\n\n<li>Improve L2 cache utilization.<\/li>\n\n\n\n<li>Reduce kernel launch overhead per element.<\/li>\n<\/ul>\n\n\n\n<p>So, let&#8217;s increase our tile size from <strong>64\u00d764<\/strong> to <strong>256\u00d7128<\/strong>:<\/p>\n\n\n<div class=\"wp-block-syntaxhighlighter-code \"><pre class=\"brush: python; title: ; notranslate\" title=\"\">\nTILE_M, TILE_N = 256, 128  # Was 64, 64\n<\/pre><\/div>\n\n\n<p><strong>The expected result<\/strong> is better memory bandwidth utilization \u2192 faster performance. 
However, the <strong>results in TFLOPS<\/strong> are:<\/p>\n\n\n\n<figure class=\"wp-block-table\"><table><thead><tr><th><strong>SeqLen<\/strong><\/th><th><strong>Baseline (64\u00d764)<\/strong><\/th><th><strong>Larger tiles (256\u00d7128)<\/strong><\/th><th><strong>Performance degradation<\/strong><\/th><\/tr><\/thead><tbody><tr><td>1,024<\/td><td>330<\/td><td>187<\/td><td><strong>-43%<\/strong><\/td><\/tr><tr><td>2,048<\/td><td>441<\/td><td>268<\/td><td><strong>-39%<\/strong><\/td><\/tr><tr><td>4,096<\/td><td>511<\/td><td>347<\/td><td><strong>-32%<\/strong><\/td><\/tr><tr><td>8,192<\/td><td>546<\/td><td>415<\/td><td><strong>-24%<\/strong><\/td><\/tr><tr><td>16,384<\/td><td>566<\/td><td>463<\/td><td><strong>-18%<\/strong><\/td><\/tr><\/tbody><\/table><figcaption class=\"wp-element-caption\"><em>Table 2. Baseline performance compared to performance with larger tile sizes, showing degradation when using larger tile sizes<\/em><\/figcaption><\/figure>\n\n\n\n<p>Performance degraded by <strong>18-43%<\/strong> across all sequence lengths. <strong>This is the trap, where<\/strong> large tiles make performance <em>worse<\/em>.<\/p>\n\n\n\n<p><strong>Why does this happen?<\/strong><\/p>\n\n\n\n<ol class=\"wp-block-list\">\n<li><strong>Compute bottleneck<\/strong>: With more elements per tile, inefficient operations (separate mul\/add, precise math) become the bottleneck.<\/li>\n\n\n\n<li><strong>Instruction overhead<\/strong>: More work per tile means more instructions before the next memory operation.<\/li>\n<\/ol>\n\n\n\n<p><strong>Lesson<\/strong>: Tile size and compute efficiency are interdependent. Large tiles only help if the computation is efficient enough to keep up.<\/p>\n\n\n\n<p><strong>NCU insight (SeqLen=1,024, NVIDIA B200)<\/strong>:<\/p>\n\n\n\n<ul class=\"wp-block-list\">\n<li>Registers\/thread jump to 168 (+31%), reducing theoretical occupancy to 18.75%<\/li>\n\n\n\n<li>Achieved occupancy drops to 16.5%<\/li>\n\n\n\n<li>Compute throughput collapses to 17.4% (the trap)<\/li>\n\n\n\n<li>Memory throughput falls to 7.4%<\/li>\n\n\n\n<li>Grid size shrinks to 512 (fewer blocks from larger tiles)<\/li>\n<\/ul>\n\n\n\n<h2 id=\"2_the_rescue_with_fast_math&nbsp;\"  class=\"wp-block-heading\">2. The rescue with fast math&nbsp;<a href=\"#2_the_rescue_with_fast_math&nbsp;\" class=\"heading-anchor-link\"><i class=\"fas fa-link\"><\/i><\/a><\/h2>\n\n\n\n<p>One of the bottlenecks is special functions: exp2 (exponential) and truediv (division). By default, these are IEEE-754 precise\u2014highly accurate, but slow.<\/p>\n\n\n\n<p>For deep learning, we can trade a tiny bit of precision for massive speedups:<\/p>\n\n\n\n<p><strong>Before<\/strong> (precise operations):<\/p>\n\n\n<div class=\"wp-block-syntaxhighlighter-code \"><pre class=\"brush: python; title: ; notranslate\" title=\"\">\np = ct.exp2(qk)\nalpha = ct.exp2(m_i - m_ij)\nacc = ct.truediv(acc, l_i)\n<\/pre><\/div>\n\n\n<p><strong>After<\/strong> (fast math):<\/p>\n\n\n<div class=\"wp-block-syntaxhighlighter-code \"><pre class=\"brush: python; title: ; notranslate\" title=\"\">\np = ct.exp2(qk, flush_to_zero=True)\nalpha = ct.exp2(m_i - m_ij, flush_to_zero=True)\nacc = ct.truediv(acc, l_i, flush_to_zero=True, rounding_mode=RMd.APPROX)\n<\/pre><\/div>\n\n\n<p><strong>What these flags do<\/strong>:<\/p>\n\n\n\n<ul class=\"wp-block-list\">\n<li><code>flush_to_zero=True<\/code>: Denormal numbers (extremely small values near zero) become exactly zero. 
This avoids slow microcode paths on the GPU.<\/li>\n\n\n\n<li><code>rounding_mode=RMd.APPROX<\/code>: Skips iterative refinement after initial hardware approximation.<\/li>\n<\/ul>\n\n\n\n<p>With fast math, we&#8217;ve &#8220;rescued&#8221; the large tiles, and the <strong>results in TFLOPS are:<\/strong><\/p>\n\n\n\n<figure class=\"wp-block-table aligncenter\"><table><thead><tr><th><strong>SeqLen<\/strong><\/th><th><strong>Larger tiles (trap)<\/strong><\/th><th><strong>Fast math (rescue)<\/strong><\/th><th><strong>Improvement<\/strong><\/th><\/tr><\/thead><tbody><tr><td>1,024<\/td><td>187<\/td><td>322<\/td><td><strong>+72%<\/strong><\/td><\/tr><tr><td>2,048<\/td><td>268<\/td><td>436<\/td><td><strong>+63%<\/strong><\/td><\/tr><tr><td>4,096<\/td><td>347<\/td><td>524<\/td><td><strong>+51%<\/strong><\/td><\/tr><tr><td>8,192<\/td><td>415<\/td><td>585<\/td><td><strong>+41%<\/strong><\/td><\/tr><tr><td>16,384<\/td><td>463<\/td><td>620<\/td><td><strong>+34%<\/strong><\/td><\/tr><\/tbody><\/table><figcaption class=\"wp-element-caption\"><em>Table 3. Performance improvement when using two fast math optimizations<\/em><\/figcaption><\/figure>\n\n\n\n<p>We now match or exceed the small-tile baseline, with gains of <strong>up to 10%<\/strong> for longer sequences.<\/p>\n\n\n\n<p><strong>NCU insight (SeqLen=1,024, NVIDIA B200)<\/strong>:<\/p>\n\n\n\n<ul class=\"wp-block-list\">\n<li>Registers\/thread: 168 (unchanged)<\/li>\n\n\n\n<li>Theoretical\/achieved occupancy: 18.75% \/ 16.6% (unchanged)<\/li>\n\n\n\n<li>Compute throughput rebounds to 24.0%<\/li>\n\n\n\n<li>Memory throughput improves to 12.9%<\/li>\n<\/ul>\n\n\n\n<h2 id=\"3_k-loop_split\"  class=\"wp-block-heading\">3. K-loop split<a href=\"#3_k-loop_split\" class=\"heading-anchor-link\"><i class=\"fas fa-link\"><\/i><\/a><\/h2>\n\n\n\n<p>For <strong>causal attention<\/strong>, we apply a triangular mask: each query can only attend to keys at earlier positions. In our baseline, we check <code>if CAUSAL: mask<\/code>&#8230; on <em>every<\/em> loop iteration.<\/p>\n\n\n\n<p>But think about it: for a query tile at position 1000, most key tiles (0-900) need <strong>no masking at all<\/strong>. Only tiles near the diagonal need the mask. And tiles beyond the query position are <strong>completely masked<\/strong> (we can skip them entirely).<\/p>\n\n\n<div class=\"wp-block-image\">\n<figure data-wp-context=\"{&quot;imageId&quot;:&quot;69efb766de116&quot;}\" data-wp-interactive=\"core\/image\" class=\"aligncenter size-full wp-lightbox-container\"><img loading=\"lazy\" decoding=\"async\" width=\"474\" height=\"285\" data-wp-class--hide=\"state.isContentHidden\" data-wp-class--show=\"state.isContentVisible\" data-wp-init=\"callbacks.setButtonStyles\" data-wp-on-async--click=\"actions.showLightbox\" data-wp-on-async--load=\"callbacks.setButtonStyles\" data-wp-on-async-window--resize=\"callbacks.setButtonStyles\" src=\"https:\/\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/03\/Casual-Attention-Matrix.webp\" alt=\"Q by K tiled causal attention matrix showing 8 tiles per side and showing how the lower triangle is computed. 
The diagonal is partially computed, and the upper triangle is skipped.\" class=\"wp-image-113192\" srcset=\"https:\/\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/03\/Casual-Attention-Matrix.webp 474w, https:\/\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/03\/Casual-Attention-Matrix-179x108.png 179w, https:\/\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/03\/Casual-Attention-Matrix-300x180.png 300w, https:\/\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/03\/Casual-Attention-Matrix-150x90.png 150w, https:\/\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/03\/Casual-Attention-Matrix-362x218.png 362w, https:\/\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/03\/Casual-Attention-Matrix-183x110.png 183w\" sizes=\"auto, (max-width: 474px) 100vw, 474px\" \/><button\n\t\t\tclass=\"lightbox-trigger\"\n\t\t\ttype=\"button\"\n\t\t\taria-haspopup=\"dialog\"\n\t\t\taria-label=\"Enlarge\"\n\t\t\tdata-wp-init=\"callbacks.initTriggerButton\"\n\t\t\tdata-wp-on-async--click=\"actions.showLightbox\"\n\t\t\tdata-wp-style--right=\"state.imageButtonRight\"\n\t\t\tdata-wp-style--top=\"state.imageButtonTop\"\n\t\t>\n\t\t\t<svg xmlns=\"http:\/\/www.w3.org\/2000\/svg\" width=\"12\" height=\"12\" fill=\"none\" viewBox=\"0 0 12 12\">\n\t\t\t\t<path fill=\"#fff\" d=\"M2 0a2 2 0 0 0-2 2v2h1.5V2a.5.5 0 0 1 .5-.5h2V0H2Zm2 10.5H2a.5.5 0 0 1-.5-.5V8H0v2a2 2 0 0 0 2 2h2v-1.5ZM8 12v-1.5h2a.5.5 0 0 0 .5-.5V8H12v2a2 2 0 0 1-2 2H8Zm2-12a2 2 0 0 1 2 2v2h-1.5V2a.5.5 0 0 0-.5-.5H8V0h2Z\" \/>\n\t\t\t<\/svg>\n\t\t<\/button><figcaption class=\"wp-element-caption\"><em>Figure 3. Tiled causal attention matrix (8 tiles per side)&nbsp;<\/em><\/figcaption><\/figure><\/div>\n\n\n<p><strong>The optimization<\/strong> splits the loop into phases:<\/p>\n\n\n<div class=\"wp-block-syntaxhighlighter-code \"><pre class=\"brush: python; title: ; notranslate\" title=\"\">\n# Calculate where masking starts being necessary\nmask_start = (input_pos + bid_x * TILE_M) \/\/ TILE_N\nmask_start = min(mask_start, k_seqlen \/\/ TILE_N)\n\n# Calculate where to stop (for causal, we exit early)\nif CAUSAL:\n    Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)\nelse:\n    Tc = ct.cdiv(k_seqlen, TILE_N)\n\nfor j in range(0, Tc):\n    # Load K and compute QK...\n    \n    # ONLY apply masking when necessary\n    if (CAUSAL or not EVEN_K) and j &gt;= mask_start:\n        offs_n = j * TILE_N + offs_n_tile\n        mask = ct.full((TILE_M, TILE_N), True, dtype=ct.bool_)\n        if not EVEN_K:\n            mask = mask &amp; (offs_n &lt; k_seqlen)\n        if CAUSAL:\n            mask = mask &amp; (offs_m &gt;= offs_n)\n        mask = ct.where(mask, 0.0, -math.inf)\n        qk += mask\n    \n    # Continue with softmax and accumulation...\n<\/pre><\/div>\n\n\n<p><strong>Why this matters<\/strong>: For a 16K sequence with 256-token tiles:<\/p>\n\n\n\n<ul class=\"wp-block-list\">\n<li>~50% of tiles are fully unmasked (no branch, no mask computation)<\/li>\n\n\n\n<li>~1 tile per row is partially masked (full logic)<\/li>\n\n\n\n<li>The rest are skipped entirely (early exit)<\/li>\n<\/ul>\n\n\n\n<p><strong>Result in TFLOPS<\/strong>:<\/p>\n\n\n\n<figure class=\"wp-block-table aligncenter\"><table><thead><tr><th><strong>SeqLen<\/strong><\/th><th><strong>Fast math<\/strong><\/th><th><strong>Loop 
split<\/strong><\/th><th><strong>Improvement<\/strong><\/th><\/tr><\/thead><tbody><tr><td>1,024<\/td><td>322<\/td><td>373<\/td><td><strong>+16%<\/strong><\/td><\/tr><tr><td>2,048<\/td><td>436<\/td><td>552<\/td><td><strong>+27%<\/strong><\/td><\/tr><tr><td>4,096<\/td><td>524<\/td><td>684<\/td><td><strong>+31%<\/strong><\/td><\/tr><tr><td>8,192<\/td><td>585<\/td><td>770<\/td><td><strong>+32%<\/strong><\/td><\/tr><tr><td>16,384<\/td><td>620<\/td><td>813<\/td><td><strong>+31%<\/strong><\/td><\/tr><\/tbody><\/table><figcaption class=\"wp-element-caption\"><em>Table 4. Performance improvement when using K-loop split optimization<\/em><\/figcaption><\/figure>\n\n\n\n<p>This is the <strong>biggest single optimization<\/strong>\u2014up to 32% speedup across all sequence lengths.<\/p>\n\n\n\n<p><strong>NCU insight (SeqLen=1,024, B200)<\/strong>:<\/p>\n\n\n\n<ul class=\"wp-block-list\">\n<li>Registers\/thread: 168 (unchanged)<\/li>\n\n\n\n<li>Theoretical\/achieved occupancy: 18.75% \/ 16.6% (unchanged)<\/li>\n\n\n\n<li>Memory throughput improves to 14.5% (less wasted work)<\/li>\n\n\n\n<li>Compute throughput remains 24.0% (work is more useful, not necessarily faster per cycle)<\/li>\n<\/ul>\n\n\n\n<h2 id=\"4_programid_remapping\"  class=\"wp-block-heading\">4. ProgramId remapping<a href=\"#4_programid_remapping\" class=\"heading-anchor-link\"><i class=\"fas fa-link\"><\/i><\/a><\/h2>\n\n\n\n<p>One subtle optimization is <strong>reversing the block order<\/strong> for causal attention. When we process tiles in reverse (bottom-right to top-left), later-launched blocks have less work due to the causal mask. This improves load balancing and reduces tail effects.<\/p>\n\n\n\n<p><strong>Before<\/strong> (standard order):<\/p>\n\n\n<div class=\"wp-block-syntaxhighlighter-code \"><pre class=\"brush: python; title: ; notranslate\" title=\"\">\nbid_x = ct.bid(0)  # Process tiles 0, 1, 2, ...\n<\/pre><\/div>\n\n\n<p><strong>After<\/strong> (reversed for causal):<\/p>\n\n\n<div class=\"wp-block-syntaxhighlighter-code \"><pre class=\"brush: python; title: ; notranslate\" title=\"\">\nif CAUSAL:\n    bid_x = NUM_M_BLOCKS - 1 - ct.bid(0)  # Process tiles N, N-1, N-2, ...\nelse:\n    bid_x = ct.bid(0)\n<\/pre><\/div>\n\n\n<p>This small change improves wave scheduling, as blocks complete more uniformly across the GPU.<\/p>\n\n\n\n<p><strong>Result in TFLOPS<\/strong>:<\/p>\n\n\n\n<figure class=\"wp-block-table aligncenter\"><table><thead><tr><th><strong>SeqLen<\/strong><\/th><th><strong>Loop split<\/strong><\/th><th><strong>Remapping<\/strong><\/th><th><strong>Improvement<\/strong><\/th><\/tr><\/thead><tbody><tr><td>1,024<\/td><td>373<\/td><td>377<\/td><td><strong>+1%<\/strong><\/td><\/tr><tr><td>2,048<\/td><td>552<\/td><td>560<\/td><td><strong>+1.5%<\/strong><\/td><\/tr><tr><td>4,096<\/td><td>684<\/td><td>696<\/td><td><strong>+1.8%<\/strong><\/td><\/tr><tr><td>8,192<\/td><td>770<\/td><td>781<\/td><td><strong>+1.5%<\/strong><\/td><\/tr><tr><td>16,384<\/td><td>813<\/td><td>835<\/td><td><strong>+2.6%<\/strong><\/td><\/tr><\/tbody><\/table><figcaption class=\"wp-element-caption\"><em>Table 5. Performance improvement after remapping the block order of the tiles&nbsp;<\/em><\/figcaption><\/figure>\n\n\n\n<p>A modest but consistent 1-3% gain, especially noticeable at longer sequences where tail effects matter most.<\/p>\n\n\n\n<h2 id=\"5_autotuning\"  class=\"wp-block-heading\">5. 
Autotuning<a href=\"#5_autotuning\" class=\"heading-anchor-link\"><i class=\"fas fa-link\"><\/i><\/a><\/h2>\n\n\n\n<p>We&#8217;ve optimized large tiles, but there&#8217;s a catch: <strong>short sequences still prefer small tiles<\/strong>.<\/p>\n\n\n\n<p>Why? With a 1,024-token sequence and 256-token tiles, we only have 4 tiles. That&#8217;s not enough to fully utilize all SMs on a B200. Smaller tiles (64\u00d764) give us 16 tiles, better filling the GPU.<\/p>\n\n\n\n<p>Rather than manually choosing a threshold, we can let <strong>cuTile&#8217;s autotuner<\/strong> benchmark multiple configurations and cache the best one for each input shape.<\/p>\n\n\n\n<p><strong>The autotuner approach<\/strong>:<\/p>\n\n\n<div class=\"wp-block-syntaxhighlighter-code \"><pre class=\"brush: python; title: ; notranslate\" title=\"\">\nfrom types import SimpleNamespace\n\nimport torch\n\ndef _fmha_autotune_configs():\n    &quot;&quot;&quot;Search space for autotuning.\n\n    The autotuner will benchmark these configurations and cache the best one\n    per input shape (sequence length, batch size, etc.).\n    &quot;&quot;&quot;\n    gpu_capability = torch.cuda.get_device_capability()\n\n    if gpu_capability in &#x5B;(12, 0), (12, 1)]:\n        # RTX 50 series (sm120, sm121)\n        yield SimpleNamespace(TILE_M=64, TILE_N=64, num_ctas=1, occupancy=2)\n    else:\n        # B200\/GB200 (sm100) - Try multiple tile sizes\n        # Autotuner will discover:\n        # - 64x64 is best for short sequences (1024-2048)\n        # - 128x128 may be best for medium sequences (4096)\n        # - 256x128 is best for long sequences (8192+)\n        yield SimpleNamespace(TILE_M=64, TILE_N=64, num_ctas=1, occupancy=2)\n        yield SimpleNamespace(TILE_M=128, TILE_N=128, num_ctas=1, occupancy=2)\n        yield SimpleNamespace(TILE_M=256, TILE_N=128, num_ctas=1, occupancy=1)\n<\/pre><\/div>\n\n\n<p><strong>How to launch with autotuning<\/strong>:<\/p>\n\n\n\n<p>Instead of calling <code>ct.launch<\/code> directly, use <code>ct_experimental.autotune_launch<\/code>:<\/p>\n\n\n<div class=\"wp-block-syntaxhighlighter-code \"><pre class=\"brush: python; title: ; notranslate\" title=\"\">\nimport math\n\nimport cuda.tile_experimental as ct_experimental\n\ndef autotune_launch_fmha(\n    stream, q, k, v, o, sm_scale, input_pos,\n    hidden_size, num_heads, query_group_size, is_causal\n):\n    batch_size, _, q_len, _ = q.shape\n\n    def _grid_fn(cfg):\n        return (math.ceil(q_len \/ cfg.TILE_M), batch_size * num_heads, 1)\n\n    def _args_fn(cfg):\n        num_m_blocks = math.ceil(q_len \/ cfg.TILE_M)\n        even_k = (k.shape&#x5B;2] % cfg.TILE_N) == 0\n        return (\n            q, k, v, o, sm_scale, input_pos,\n            hidden_size, num_heads, cfg.TILE_M, cfg.TILE_N,\n            query_group_size, is_causal, even_k, num_m_blocks,\n        )\n\n    ct_experimental.autotune_launch(\n        stream,\n        grid_fn=_grid_fn,\n        kernel=fmha_kernel,\n        args_fn=_args_fn,\n        hints_fn=lambda cfg: {&quot;num_ctas&quot;: cfg.num_ctas, &quot;occupancy&quot;: cfg.occupancy},\n        search_space=_fmha_autotune_configs,\n    )\n<\/pre><\/div>\n\n\n<p>Note: The autotuner API may be subject to change.<\/p>\n\n\n\n<p>The autotuner works intelligently:<\/p>\n\n\n\n<ol class=\"wp-block-list\">\n<li><strong>First call with seq_len=1024<\/strong>: Benchmarks all 3 configs, caches best one<\/li>\n\n\n\n<li><strong>First call with seq_len=2048<\/strong>: Benchmarks all 3 configs, caches best one<\/li>\n\n\n\n<li><strong>Subsequent calls<\/strong>: Uses cached config (zero 
overhead)<\/li>\n<\/ol>\n\n\n\n<p>The cache key includes tensor shapes, so different sequence lengths automatically get different optimal configurations.<\/p>\n\n\n\n<p><strong>Result in TFLOPS<\/strong>:<\/p>\n\n\n\n<figure class=\"wp-block-table aligncenter\"><table><thead><tr><th><strong>SeqLen<\/strong><\/th><th><strong>Baseline<\/strong><\/th><th><strong>Remapping<\/strong><\/th><th><strong>Autotune<\/strong><\/th><th><strong>Speedup vs baseline<\/strong><\/th><\/tr><\/thead><tbody><tr><td>1,024<\/td><td>330<\/td><td>377<\/td><td><strong>548<\/strong><\/td><td><strong>1.66x<\/strong><\/td><\/tr><tr><td>2,048<\/td><td>441<\/td><td>560<\/td><td><strong>708<\/strong><\/td><td><strong>1.61x<\/strong><\/td><\/tr><tr><td>4,096<\/td><td>511<\/td><td>696<\/td><td><strong>817<\/strong><\/td><td><strong>1.60x<\/strong><\/td><\/tr><tr><td>8,192<\/td><td>546<\/td><td>781<\/td><td><strong>887<\/strong><\/td><td><strong>1.62x<\/strong><\/td><\/tr><tr><td>16,384<\/td><td>566<\/td><td>835<\/td><td><strong>918<\/strong><\/td><td><strong>1.62x<\/strong><\/td><\/tr><\/tbody><\/table><figcaption class=\"wp-element-caption\"><em>Table 6. Original baseline compared to step 5 and to step 6 autotuned results<\/em><\/figcaption><\/figure>\n\n\n\n<p>The autotuner discovers that 64\u00d764 tiles are best for sequences \u22642,048, then transitions to larger tiles for longer sequences. This delivers <strong>45% additional performance<\/strong> at short sequences compared to fixed large tiles, while maintaining peak performance at long sequences.<\/p>\n\n\n\n<p><strong>What the autotuner chose<\/strong> (on B200):<\/p>\n\n\n\n<ul class=\"wp-block-list\">\n<li>SeqLen 1,024: 64\u00d764 tiles (high parallelism)<\/li>\n\n\n\n<li>SeqLen 2,048: 64\u00d764 or 128\u00d7128 tiles (balanced)<\/li>\n\n\n\n<li>SeqLen 4,096+: 128\u00d7128 or 256\u00d7128 tiles (memory efficiency)<\/li>\n<\/ul>\n\n\n\n<p>We now achieve optimal performance across all sequence lengths without manual tuning.<\/p>\n\n\n\n<h2 id=\"summary_the_optimization_stack\"  class=\"wp-block-heading\">Summary: The optimization stack<a href=\"#summary_the_optimization_stack\" class=\"heading-anchor-link\"><i class=\"fas fa-link\"><\/i><\/a><\/h2>\n\n\n\n<figure class=\"wp-block-table aligncenter\"><table><thead><tr><th><strong>Optimization<\/strong><\/th><th><strong>Key insight<\/strong><\/th><th><strong>Impact<\/strong><\/th><\/tr><\/thead><tbody><tr><td>Baseline (64\u00d764)<\/td><td>Correct but unoptimized<\/td><td>Baseline<\/td><\/tr><tr><td>Large tiles (256\u00d7128)<\/td><td><strong>TRAP<\/strong>: 18-43% slower!<\/td><td>-18% to -43%<\/td><\/tr><tr><td>+ Fast math (FTZ, APPROX)<\/td><td><strong>RESCUE<\/strong>: Large tiles now pay off<\/td><td>+34% to +72% from trap<\/td><\/tr><tr><td>+ K-loop split<\/td><td>Biggest single optimization<\/td><td>+16% to +32%<\/td><\/tr><tr><td>+ ProgramId remapping<\/td><td>Better load balancing<\/td><td>+1% to +3%<\/td><\/tr><tr><td>+ Autotuning<\/td><td>Optimal tiles per sequence<\/td><td>+10% to +45%<\/td><\/tr><\/tbody><\/table><figcaption class=\"wp-element-caption\"><em>Table 7. 
Step-by-step optimization results with performance impacts for each step<\/em><\/figcaption><\/figure>\n\n\n\n<p><strong>Final speedup: 1.60x-1.66x<\/strong> across all sequence lengths.<\/p>\n\n\n\n<h2 id=\"getting_started\"  class=\"wp-block-heading\">Getting started<a href=\"#getting_started\" class=\"heading-anchor-link\"><i class=\"fas fa-link\"><\/i><\/a><\/h2>\n\n\n\n<p>Writing high-performance kernels is rarely about finding one &#8220;magic&#8221; setting. As we saw with the &#8220;trap and rescue&#8221;:<\/p>\n\n\n\n<ol class=\"wp-block-list\">\n<li><strong>Optimizations are interdependent<\/strong>: Large tiles were <em>slower<\/em> until we fixed the math. You can&#8217;t evaluate tile size in isolation.<\/li>\n\n\n\n<li><strong>Math matters<\/strong>: Flags like <code>flush_to_zero<\/code> and <code>APPROX<\/code> are critical for unlocking Tensor Core throughput. Precise math is often overkill for deep learning.<\/li>\n\n\n\n<li><strong>Algorithmic wins compound<\/strong>: K-loop splitting gave us the biggest single improvement (up to 32%) by avoiding unnecessary work.<\/li>\n\n\n\n<li><strong>Autotuning beats manual heuristics<\/strong>: cuTile&#8217;s autotuner discovers optimal tile sizes per sequence length (64\u00d764 for short sequences, 256\u00d7128 for long), delivering 10-45% gains over fixed configurations.<\/li>\n\n\n\n<li><strong>Cumulative effects are multiplicative<\/strong>: The full optimization stack delivers <strong>1.60x-1.66x speedup<\/strong> across all sequence lengths\u2014far more than any single optimization alone.<\/li>\n<\/ol>\n\n\n\n<p><a href=\"https:\/\/github.com\/NVIDIA\/cutile-python\"><strong>cuTile<\/strong><\/a> enables developers to express these optimizations\u2014tiling, fast math controls, loop splitting, autotune\u2014in clean, readable Python code while generating highly optimized PTX for NVIDIA GPUs.<\/p>\n\n\n\n<p>You can find the completely optimized kernel in the <a href=\"https:\/\/github.com\/NVIDIA\/TileGym\">TileGym repository<\/a>. Happy hacking.<\/p>\n","protected":false},"excerpt":{"rendered":"<p>In this post, we dive into one of the most critical workloads in modern AI: Flash Attention, where you\u2019ll learn: Environment requirements: See the quickstart doc for more information on installing cuTile Python. What is attention? The attention mechanism is the computational heart of transformer models. 
Given a sequence of tokens, attention enables each token &hellip; <a href=\"https:\/\/developer.nvidia.com\/blog\/tuning-flash-attention-for-peak-performance-in-nvidia-cuda-tile\/\">Continued<\/a><\/p>\n","protected":false},"author":3205,"featured_media":113250,"comment_status":"closed","ping_status":"closed","sticky":false,"template":"","format":"standard","meta":{"_acf_changed":false,"publish_to_discourse":"","publish_post_category":"318","wpdc_auto_publish_overridden":"1","wpdc_topic_tags":"","wpdc_pin_topic":"","wpdc_pin_until":"","discourse_post_id":"1768517","discourse_permalink":"https:\/\/forums.developer.nvidia.com\/t\/tuning-flash-attention-for-peak-performance-in-nvidia-cuda-tile\/362452","wpdc_publishing_response":"success","wpdc_publishing_error":"","nv_subtitle":"","ai_post_summary":"","footnotes":"","_links_to":"","_links_to_target":""},"categories":[3110,696,4146,1903],"tags":[4897,4896,453],"coauthors":[5035,5036,5037,1635],"class_list":["post-113179","post","type-post","status-publish","format-standard","has-post-thumbnail","hentry","category-generative-ai","category-data-science","category-development","category-features","tag-cuda-tile","tag-cutile","tag-featured","tagify_workload-data-science"],"acf":{"post_industry":["General"],"post_products":["CUDA"],"post_learning_levels":["Advanced Technical"],"post_content_types":["Tutorial"],"post_collections":""},"jetpack_featured_media_url":"https:\/\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/03\/CUDA-Tile-Flash-Attention.webp","primary_category":{"category":"Developer Tools &amp; Techniques","link":"https:\/\/developer.nvidia.com\/blog\/category\/development\/","id":4146,"data_source":""},"nv_translations":[{"language":"zh_CN","title":"\u5728 NVIDIA CUDA Tile \u4e2d\u8c03\u6574 Flash Attention \u4ee5\u5b9e\u73b0\u5cf0\u503c\u6027\u80fd","post_id":16844}],"jetpack_shortlink":"https:\/\/wp.me\/pcCQAL-trt","jetpack_likes_enabled":true,"jetpack_sharing_enabled":true,"_links":{"self":[{"href":"https:\/\/developer-blogs.nvidia.com\/wp-json\/wp\/v2\/posts\/113179","targetHints":{"allow":["GET"]}}],"collection":[{"href":"https:\/\/developer-blogs.nvidia.com\/wp-json\/wp\/v2\/posts"}],"about":[{"href":"https:\/\/developer-blogs.nvidia.com\/wp-json\/wp\/v2\/types\/post"}],"author":[{"embeddable":true,"href":"https:\/\/developer-blogs.nvidia.com\/wp-json\/wp\/v2\/users\/3205"}],"replies":[{"embeddable":true,"href":"https:\/\/developer-blogs.nvidia.com\/wp-json\/wp\/v2\/comments?post=113179"}],"version-history":[{"count":61,"href":"https:\/\/developer-blogs.nvidia.com\/wp-json\/wp\/v2\/posts\/113179\/revisions"}],"predecessor-version":[{"id":115857,"href":"https:\/\/developer-blogs.nvidia.com\/wp-json\/wp\/v2\/posts\/113179\/revisions\/115857"}],"wp:featuredmedia":[{"embeddable":true,"href":"https:\/\/developer-blogs.nvidia.com\/wp-json\/wp\/v2\/media\/113250"}],"wp:attachment":[{"href":"https:\/\/developer-blogs.nvidia.com\/wp-json\/wp\/v2\/media?parent=113179"}],"wp:term":[{"taxonomy":"category","embeddable":true,"href":"https:\/\/developer-blogs.nvidia.com\/wp-json\/wp\/v2\/categories?post=113179"},{"taxonomy":"post_tag","embeddable":true,"href":"https:\/\/developer-blogs.nvidia.com\/wp-json\/wp\/v2\/tags?post=113179"},{"taxonomy":"author","embeddable":true,"href":"https:\/\/developer-blogs.nvidia.com\/wp-json\/wp\/v2\/coauthors?post=113179"}],"curies":[{"name":"wp","href":"https:\/\/api.w.org\/{rel}","templated":true}]}}