{"id":112140,"date":"2026-02-03T09:30:00","date_gmt":"2026-02-03T17:30:00","guid":{"rendered":"https:\/\/developer.nvidia.com\/blog\/?p=112140"},"modified":"2026-03-05T11:20:03","modified_gmt":"2026-03-05T19:20:03","slug":"accelerating-long-context-model-training-in-jax-and-xla","status":"publish","type":"post","link":"https:\/\/developer.nvidia.com\/blog\/accelerating-long-context-model-training-in-jax-and-xla\/","title":{"rendered":"Accelerating Long-Context Model Training in JAX and XLA"},"content":{"rendered":"\n<p><a href=\"https:\/\/www.nvidia.com\/en-us\/glossary\/large-language-models\/\">Large language models (LLMs)<\/a> are rapidly expanding their context windows, with recent models supporting sequences of 128K tokens, 256K tokens, and beyond. However, training these models with extended context lengths presents significant computational and communication challenges. As context lengths grow, the memory and communication overhead of attention mechanisms scales quadratically, creating bottlenecks that traditional parallelism strategies struggle to address efficiently.<\/p>\n\n\n\n<p>This post demonstrates that integrating the <a href=\"https:\/\/developer.nvidia.com\/nvshmem\">NVSHMEM<\/a> communication library into the Accelerated Linear Algebra (XLA) compiler optimizes context parallelism. This integration enables efficient training of the Llama 3 8B model in the JAX framework with sequences of up to 256K tokens. 
Our results show that NVSHMEM provides up to 36% speedup over <a href=\"https:\/\/developer.nvidia.com\/nccl\">NVIDIA Collective Communications Library (NCCL)<\/a> for long-context training workloads, particularly when combined with tensor parallelism across multiple nodes.<\/p>\n\n\n\n<h2 id=\"the_long-context_training_challenge\"  class=\"wp-block-heading\">The long-context training challenge<a href=\"#the_long-context_training_challenge\" class=\"heading-anchor-link\"><i class=\"fas fa-link\"><\/i><\/a><\/h2>\n\n\n\n<p>To understand why NVSHMEM provides significant speedups for long-context training, it\u2019s necessary to first understand how context parallelism works and the unique communication patterns it creates. This section explains why the fine-grained, latency-sensitive communication of ring attention makes it an ideal candidate for optimization.<\/p>\n\n\n\n<h3 id=\"context_parallelism_and_ring_attention\"  class=\"wp-block-heading\">Context parallelism and ring attention<a href=\"#context_parallelism_and_ring_attention\" class=\"heading-anchor-link\"><i class=\"fas fa-link\"><\/i><\/a><\/h3>\n\n\n\n<p><em>Context parallelism <\/em>(CP) is a parallelization strategy designed specifically for handling long sequences in transformer models. Unlike data parallelism, which splits the batch, or tensor parallelism, which splits the model, context parallelism splits the sequence dimension across multiple devices.<\/p>\n\n\n\n<p><em>Ring attention<\/em> is a memory-efficient implementation of context parallelism that uses a ring-based communication pattern. 
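<\/p>\n\n\n\n<p>The KV rotation at the heart of ring attention can be illustrated with a small pure-Python simulation of the ring schedule (illustrative only; not the MaxText implementation):<\/p>

```python
def ring_schedule(num_devices):
    # Device i starts out holding its own KV block i. At every step,
    # each device passes its current block to the next device in the
    # ring, so device i receives the block held by device i - 1.
    holdings = list(range(num_devices))
    schedule = []
    for _ in range(num_devices):
        schedule.append(list(holdings))
        holdings = [holdings[(i - 1) % num_devices] for i in range(num_devices)]
    return schedule

# After num_devices steps, every device has held every KV block exactly
# once, so the accumulated attention matches full (non-ring) attention.
schedule = ring_schedule(4)
for device in range(4):
    assert sorted(step[device] for step in schedule) == [0, 1, 2, 3]
```

<p>Each device thus attends over the full sequence while storing only one remote KV block at a time.<\/p>\n\n\n\n<p>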
During attention computation, each device:<\/p>\n\n\n\n<ul class=\"wp-block-list\">\n<li>Processes its local portion of the sequence<\/li>\n\n\n\n<li>Exchanges Key Value (KV) tensors with neighboring devices in a ring topology<\/li>\n\n\n\n<li>Incrementally computes attention scores as KV blocks circulate around the ring<\/li>\n<\/ul>\n\n\n\n<p>This approach reduces peak memory usage while maintaining mathematical equivalence to standard attention, making it possible to train with sequences that would otherwise exceed GPU memory capacity.<\/p>\n\n\n\n<h3 id=\"communication_patterns_in_ring_attention\"  class=\"wp-block-heading\">Communication patterns in ring attention<a href=\"#communication_patterns_in_ring_attention\" class=\"heading-anchor-link\"><i class=\"fas fa-link\"><\/i><\/a><\/h3>\n\n\n\n<p>Ring attention involves frequent, fine-grained communication operations:<\/p>\n\n\n\n<ul class=\"wp-block-list\">\n<li><strong>Point-to-point transfers<\/strong>: Sending KV tensors to the next device in the ring<\/li>\n\n\n\n<li><strong>Overlapped compute-communication<\/strong>: Computing attention on current KV blocks while fetching the next blocks<\/li>\n\n\n\n<li><strong>Low-latency requirement: <\/strong>KV transfers are on the critical path and must complete before attention can proceed<\/li>\n<\/ul>\n\n\n\n<p>These characteristics make ring attention an ideal candidate for low-latency communication libraries like NVSHMEM.<\/p>\n\n\n\n<h2 id=\"gpu-optimized_communication_with_nvshmem\"  class=\"wp-block-heading\">GPU-optimized communication with NVSHMEM<a href=\"#gpu-optimized_communication_with_nvshmem\" class=\"heading-anchor-link\"><i class=\"fas fa-link\"><\/i><\/a><\/h2>\n\n\n\n<p><a href=\"https:\/\/developer.nvidia.com\/nvshmem\">NVSHMEM<\/a> is a communication library that implements the OpenSHMEM parallel programming model for NVIDIA GPUs. 
It provides several key features that distinguish it from traditional communication libraries, including symmetric memory, stream-aware communication, copy engine offloading, and more, as detailed below.<\/p>\n\n\n\n<h3 id=\"symmetric_memory\"  class=\"wp-block-heading\">Symmetric memory<a href=\"#symmetric_memory\" class=\"heading-anchor-link\"><i class=\"fas fa-link\"><\/i><\/a><\/h3>\n\n\n\n<p>NVSHMEM provides a partitioned global address space (PGAS) resident in GPU memories. Applications allocate buffers from this symmetric heap using <code>nvshmem_malloc<\/code>, and these pointers can be used directly in communication operations. For example:<\/p>\n\n\n<div class=\"wp-block-syntaxhighlighter-code \"><pre class=\"brush: cpp; title: ; notranslate\" title=\"\">\nint32_t *src_d = (int32_t *)nvshmem_malloc(1024 * sizeof(int32_t));\nint32_t *dest_d = (int32_t *)nvshmem_malloc(1024 * sizeof(int32_t));\n\/\/ Typed sum reduction across the team, enqueued on the default CUDA stream\nint ret = nvshmemx_int32_sum_reduce_on_stream(NVSHMEM_TEAM_WORLD, dest_d, src_d, 1024, 0);\n<\/pre><\/div>\n\n\n<figure data-wp-context=\"{&quot;imageId&quot;:&quot;69efb75975c95&quot;}\" data-wp-interactive=\"core\/image\" class=\"wp-block-image aligncenter size-full wp-lightbox-container\"><img loading=\"lazy\" decoding=\"async\" width=\"1999\" height=\"583\" data-wp-class--hide=\"state.isContentHidden\" data-wp-class--show=\"state.isContentVisible\" data-wp-init=\"callbacks.setButtonStyles\" data-wp-on-async--click=\"actions.showLightbox\" data-wp-on-async--load=\"callbacks.setButtonStyles\" data-wp-on-async-window--resize=\"callbacks.setButtonStyles\" src=\"https:\/\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/01\/symmetric-memory-heap-nvshmem.png\" alt=\"Symmetric memory regions (shared) and private memory regions at each PE. 
The aggregation of the shared memory segments across all PEs is referred to as a partitioned global address space (PGAS).\n\" class=\"wp-image-112167\" srcset=\"https:\/\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/01\/symmetric-memory-heap-nvshmem.png 1999w, https:\/\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/01\/symmetric-memory-heap-nvshmem-179x52.png 179w, https:\/\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/01\/symmetric-memory-heap-nvshmem-300x87.png 300w, https:\/\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/01\/symmetric-memory-heap-nvshmem-768x224.png 768w, https:\/\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/01\/symmetric-memory-heap-nvshmem-625x182.png 625w, https:\/\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/01\/symmetric-memory-heap-nvshmem-1536x448.png 1536w, https:\/\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/01\/symmetric-memory-heap-nvshmem-645x188.png 645w, https:\/\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/01\/symmetric-memory-heap-nvshmem-500x146.png 500w, https:\/\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/01\/symmetric-memory-heap-nvshmem-160x47.png 160w, https:\/\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/01\/symmetric-memory-heap-nvshmem-362x106.png 362w, https:\/\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/01\/symmetric-memory-heap-nvshmem-377x110.png 377w, https:\/\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/01\/symmetric-memory-heap-nvshmem-1024x299.png 1024w, https:\/\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/01\/symmetric-memory-heap-nvshmem-960x280.png 960w\" sizes=\"auto, (max-width: 1999px) 100vw, 1999px\" 
\/><button\n\t\t\tclass=\"lightbox-trigger\"\n\t\t\ttype=\"button\"\n\t\t\taria-haspopup=\"dialog\"\n\t\t\taria-label=\"Enlarge\"\n\t\t\tdata-wp-init=\"callbacks.initTriggerButton\"\n\t\t\tdata-wp-on-async--click=\"actions.showLightbox\"\n\t\t\tdata-wp-style--right=\"state.imageButtonRight\"\n\t\t\tdata-wp-style--top=\"state.imageButtonTop\"\n\t\t>\n\t\t\t<svg xmlns=\"http:\/\/www.w3.org\/2000\/svg\" width=\"12\" height=\"12\" fill=\"none\" viewBox=\"0 0 12 12\">\n\t\t\t\t<path fill=\"#fff\" d=\"M2 0a2 2 0 0 0-2 2v2h1.5V2a.5.5 0 0 1 .5-.5h2V0H2Zm2 10.5H2a.5.5 0 0 1-.5-.5V8H0v2a2 2 0 0 0 2 2h2v-1.5ZM8 12v-1.5h2a.5.5 0 0 0 .5-.5V8H12v2a2 2 0 0 1-2 2H8Zm2-12a2 2 0 0 1 2 2v2h-1.5V2a.5.5 0 0 0-.5-.5H8V0h2Z\" \/>\n\t\t\t<\/svg>\n\t\t<\/button><figcaption class=\"wp-element-caption\"><em>Figure 1. Symmetric memory heap in NVSHMEM<\/em><\/figcaption><\/figure>\n\n\n\n<h3 id=\"stream-aware_communication\"  class=\"wp-block-heading\">Stream-aware communication<a href=\"#stream-aware_communication\" class=\"heading-anchor-link\"><i class=\"fas fa-link\"><\/i><\/a><\/h3>\n\n\n\n<p>NVSHMEM provides peer-to-peer (P2P) on-stream APIs (such as <code>put_nbi_on_stream<\/code> and <code>signal_on_stream<\/code>) to efficiently move data and provide low-latency synchronization over P2P-connected GPUs.&nbsp;<\/p>\n\n\n\n<p>One of the key advantages of these APIs over traditional host-initiated communication is that they execute with a zero streaming multiprocessor (SM) footprint, leveraging the copy engine (CE) and <a href=\"https:\/\/docs.nvidia.com\/cuda\/cuda-driver-api\/group__CUDA__MEMOP.html\">stream memory operations<\/a> capabilities of GPU hardware. 
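<\/p>\n\n\n\n<p>The put-with-signal pattern these APIs enable can be modeled in plain Python with threads (a conceptual sketch of the semantics only; the names below are illustrative, not the NVSHMEM API):<\/p>

```python
import threading

class Peer:
    # Models one GPU's symmetric buffer plus a signal flag that a remote
    # peer can set (cf. put_nbi_on_stream followed by signal_on_stream).
    def __init__(self):
        self.buffer = None
        self.flag = threading.Event()

def put_with_signal(src_data, dst):
    dst.buffer = src_data  # one-sided write into the peer's buffer
    dst.flag.set()         # low-latency signal: the data has arrived

def wait_and_consume(peer):
    peer.flag.wait()       # the consumer blocks only on the signal
    peer.flag.clear()
    return peer.buffer

peer = Peer()
sender = threading.Thread(target=put_with_signal, args=([1, 2, 3], peer))
sender.start()
received = wait_and_consume(peer)
sender.join()
print(received)  # [1, 2, 3]
```

<p>The receiver never issues a matching receive call; it only waits on the signal, which is what keeps the exchange off the critical path until the data is actually needed.<\/p>\n\n\n\n<p>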
Some of the underlying CUDA interfaces include:<\/p>\n\n\n\n<ul class=\"wp-block-list\">\n<li><strong>Direct GPU-to-GPU transfers<\/strong>: Similar to <code>cudaMemcpyAsync<\/code>, but with lower latency through optimized data paths<\/li>\n\n\n\n<li><strong>Fine-grained synchronization<\/strong>: Using <code>cuStreamWriteValue32<\/code> and <code>cuStreamWaitValue32<\/code> primitives for efficient signaling between devices without CPU involvement<\/li>\n<\/ul>\n\n\n\n<p>In addition to the P2P on-stream APIs, NVSHMEM also provides collective operations commonly used in AI workloads, such as AllReduce (<code>reduce_on_stream<\/code>, for example). These collectives leverage the SHARP in-network reduction and multicast acceleration features of <a href=\"https:\/\/www.nvidia.com\/en-us\/data-center\/nvlink\/\">NVIDIA NVLink Switch<\/a> to enable latency-optimized one-shot and throughput-optimized two-shot AllReduce algorithms. The underlying CUDA interface includes the <a href=\"https:\/\/docs.nvidia.com\/cuda\/parallel-thread-execution\/#data-movement-and-conversion-instructions-multimem\">multimem ISA<\/a>, which further reduces the SM footprint because primitives such as reductions and broadcasts are offloaded to the switch.<\/p>\n\n\n\n<p>Both features enable useful compute-communication pipelining: because the communication consumes few or no SMs, most or all GPU SMs remain available for compute operations that overlap in time on the same CUDA stream.<\/p>\n\n\n\n<h3 id=\"cuda_graphs_interoperability\"  class=\"wp-block-heading\">CUDA Graphs interoperability<a href=\"#cuda_graphs_interoperability\" class=\"heading-anchor-link\"><i class=\"fas fa-link\"><\/i><\/a><\/h3>\n\n\n\n<p>NVSHMEM operations can be captured into <a href=\"https:\/\/docs.nvidia.com\/cuda\/cuda-programming-guide\/04-special-topics\/cuda-graphs.html#\">CUDA Graphs<\/a>, enabling:<\/p>\n\n\n\n<ul class=\"wp-block-list\">\n<li>Amortized kernel launch overhead across multiple 
iterations<\/li>\n\n\n\n<li>Optimized execution scheduling by the CUDA runtime<\/li>\n\n\n\n<li>Seamless composition with other graph-captured operations<\/li>\n<\/ul>\n\n\n\n<p>This composability is crucial for production training frameworks that rely on CUDA Graphs for performance optimization.<\/p>\n\n\n\n<h2 id=\"integrating_nvshmem_and_xla\"  class=\"wp-block-heading\">Integrating NVSHMEM and XLA<a href=\"#integrating_nvshmem_and_xla\" class=\"heading-anchor-link\"><i class=\"fas fa-link\"><\/i><\/a><\/h2>\n\n\n\n<p>This section describes how NVSHMEM is integrated into the XLA compiler infrastructure, covering runtime flags, automatic backend selection heuristics, and the compilation flow.<\/p>\n\n\n\n<h3 id=\"runtime_control_through_debug_options\"  class=\"wp-block-heading\">Runtime control through debug options<a href=\"#runtime_control_through_debug_options\" class=\"heading-anchor-link\"><i class=\"fas fa-link\"><\/i><\/a><\/h3>\n\n\n\n<p>XLA exposes a runtime flag for dynamic control:<\/p>\n\n\n<div class=\"wp-block-syntaxhighlighter-code \"><pre class=\"brush: plain; title: ; notranslate\" title=\"\">\nXLA_FLAGS=&quot;--xla_gpu_experimental_enable_nvshmem=true&quot;\n<\/pre><\/div>\n\n\n<p>This flag is defined in <code>xla\/debug_options_flags.cc<\/code> and allows users to enable or disable NVSHMEM without recompilation (default value = false). The &#8220;experimental&#8221; prefix indicates that the API may evolve as the feature matures.<\/p>\n\n\n\n<h3 id=\"automatic_backend_selection\"  class=\"wp-block-heading\">Automatic backend selection<a href=\"#automatic_backend_selection\" class=\"heading-anchor-link\"><i class=\"fas fa-link\"><\/i><\/a><\/h3>\n\n\n\n<p>The <code>CollectiveBackendAssigner<\/code> pass in the compilation pipeline determines which communication backend to use based on workload characteristics. 
This pass encodes the decision logic of the integration.<\/p>\n\n\n\n<h4 class=\"wp-block-heading\">Selection heuristics<\/h4>\n\n\n\n<p>The compiler analyzes each collective operation and decides whether to use NVSHMEM based on three key criteria:<\/p>\n\n\n\n<ol class=\"wp-block-list\">\n<li><strong>Single device<\/strong>: Use NVSHMEM when only one device is visible per process (no network overhead)<\/li>\n\n\n\n<li><strong>Single partition<\/strong>: Use NVSHMEM when all participating devices in the collective operation are managed by the same process<\/li>\n\n\n\n<li><strong>NVLink domain<\/strong>: Use NVSHMEM for intranode communication over <a href=\"https:\/\/www.nvidia.com\/en-us\/data-center\/nvlink\/\">NVIDIA NVLink<\/a> fabric<\/li>\n<\/ol>\n\n\n\n<p>Additionally, message size heuristics apply:<\/p>\n\n\n\n<ul class=\"wp-block-list\">\n<li><strong>AllReduce operations<\/strong>: Only use NVSHMEM if the message size is below a threshold (typically 16 MB). For larger messages, fall back to NCCL, which is optimized for bandwidth.<\/li>\n\n\n\n<li><strong>CollectivePermute operations<\/strong>: Always use NVSHMEM, regardless of message size (no threshold applied).<\/li>\n\n\n\n<li><strong>Rationale<\/strong>: AllReduce benefits from NCCL&#8217;s ring and tree algorithms for large messages, while the point-to-point nature of CollectivePermute makes NVSHMEM&#8217;s low latency ideal at any message size.<\/li>\n<\/ul>\n\n\n\n<h3 id=\"jax_framework_integration\"  class=\"wp-block-heading\">JAX framework integration<a href=\"#jax_framework_integration\" class=\"heading-anchor-link\"><i class=\"fas fa-link\"><\/i><\/a><\/h3>\n\n\n\n<p>The strength of this architecture lies in its complete transparency to Python frameworks. 
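<\/p>\n\n\n\n<p>Before looking at the JAX-level view, the heuristics above can be condensed into a small predicate. This is a simplified model for illustration; the actual logic lives in the <code>CollectiveBackendAssigner<\/code> pass:<\/p>

```python
ALLREDUCE_THRESHOLD_BYTES = 16 * 1024 * 1024  # typical 16 MB cutoff

def prefer_nvshmem(op, message_bytes, single_partition, in_nvlink_domain):
    # Topology gate: all participating devices managed by one process,
    # or reachable over the NVLink fabric (a single-device run trivially
    # satisfies both conditions).
    if not (single_partition or in_nvlink_domain):
        return False
    if op == 'collective-permute':
        return True  # point-to-point: latency-bound at any message size
    if op == 'all-reduce':
        # Small messages are latency-bound; large ones are bandwidth-bound,
        # where NCCL's ring and tree algorithms win.
        return message_bytes < ALLREDUCE_THRESHOLD_BYTES
    return False

print(prefer_nvshmem('collective-permute', 64 * 1024 * 1024, True, True))  # True
print(prefer_nvshmem('all-reduce', 64 * 1024 * 1024, True, True))          # False
```

<p>In the real pass, these properties are derived from the compiled module and the device topology rather than passed in by the user.<\/p>\n\n\n\n<p>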
A JAX developer writes standard collective operations:<\/p>\n\n\n<div class=\"wp-block-syntaxhighlighter-code \"><pre class=\"brush: python; title: ; notranslate\" title=\"\">\nimport jax\nimport jax.numpy as jnp\n\ndef collective_permute_example(x):\n    # Shift data from each device to the next device in a ring\n    n = jax.device_count()\n    perm = &#x5B;(i, (i + 1) % n) for i in range(n)]\n    return jax.lax.ppermute(x, axis_name=&#039;devices&#039;, perm=perm)\n\n# pmap binds the &#039;devices&#039; axis, with one leading-axis slice per device\ndata = jnp.arange(jax.device_count() * 4).reshape(jax.device_count(), 4)\n\n# The compiler automatically selects NVSHMEM when appropriate\nresult = jax.pmap(collective_permute_example, axis_name=&#039;devices&#039;)(data)\n<\/pre><\/div>\n\n\n<p>The XLA compiler analyzes this <code>ppermute<\/code> (collective permute) operation and automatically performs the following steps:<\/p>\n\n\n\n<ul class=\"wp-block-list\">\n<li>Applies the heuristics: single device, single partition, or within an NVLink domain<\/li>\n\n\n\n<li>Recognizes a CollectivePermute operation (no message size threshold applies)<\/li>\n\n\n\n<li>Selects NVSHMEM for optimal point-to-point communication<\/li>\n\n\n\n<li>Generates thunks that invoke NVSHMEM host APIs at runtime<\/li>\n\n\n\n<li>At runtime, these host APIs (for example, <code>nvshmemx_float_sum_reduce_on_stream<\/code> and <code>nvshmemx_float_put_nbi_on_stream<\/code>) enqueue operations on CUDA streams<\/li>\n<\/ul>\n\n\n\n<p>This end-to-end integration means that high-level JAX code automatically benefits from NVSHMEM performance without requiring any user-level changes or annotations.<\/p>\n\n\n\n<h2 id=\"experimental_methodology\"  class=\"wp-block-heading\">Experimental methodology<a href=\"#experimental_methodology\" class=\"heading-anchor-link\"><i class=\"fas fa-link\"><\/i><\/a><\/h2>\n\n\n\n<p>To evaluate NVSHMEM performance benefits, the team conducted experiments on Llama 3 8B across a range of sequence lengths (64K to 256K tokens). 
This section details the model setup, hardware configuration, and the metrics used to compare NVSHMEM against the NCCL baseline.<\/p>\n\n\n\n<h3 id=\"model_configuration\"  class=\"wp-block-heading\">Model configuration<a href=\"#model_configuration\" class=\"heading-anchor-link\"><i class=\"fas fa-link\"><\/i><\/a><\/h3>\n\n\n\n<p>The team evaluated NVSHMEM-accelerated context parallelism on the Llama 3 8B model with the following configurations.<\/p>\n\n\n\n<ul class=\"wp-block-list\">\n<li><strong>Model<\/strong>: Llama 3 8B<\/li>\n\n\n\n<li><strong>Precision<\/strong>: BF16<\/li>\n\n\n\n<li><strong>Context parallel strategy<\/strong>: Ring attention<\/li>\n\n\n\n<li><strong>Framework<\/strong>: MaxText (JAX-based training framework)<\/li>\n\n\n\n<li><strong>Hardware<\/strong>: <a href=\"https:\/\/www.nvidia.com\/en-us\/data-center\/gb200-nvl72\/\">NVIDIA GB200 NVL72<\/a><\/li>\n\n\n\n<li><strong>Docker image<\/strong>: Available through <a href=\"https:\/\/github.com\/nvidia\/JAX-Toolbox\/pkgs\/container\/jax\">NVIDIA\/JAX-Toolbox<\/a> <\/li>\n\n\n\n<li><strong>JAX version<\/strong>: JAX 0.6.2 and later&nbsp;<\/li>\n<\/ul>\n\n\n\n<h3 id=\"parallelism_configurations\"  class=\"wp-block-heading\">Parallelism configurations<a href=\"#parallelism_configurations\" class=\"heading-anchor-link\"><i class=\"fas fa-link\"><\/i><\/a><\/h3>\n\n\n\n<p>Various combinations of parallelism strategies were tested across different sequence lengths (Table 1).<\/p>\n\n\n\n<figure class=\"wp-block-table aligncenter\"><table class=\"has-fixed-layout\"><tbody><tr><td class=\"has-text-align-center\" data-align=\"center\"><strong>Sequence length<\/strong><\/td><td class=\"has-text-align-center\" data-align=\"center\"><strong>Nodes<\/strong><\/td><td class=\"has-text-align-center\" data-align=\"center\"><strong>GPUs<\/strong><\/td><td class=\"has-text-align-center\" data-align=\"center\"><strong>Context parallelism<\/strong><\/td><td class=\"has-text-align-center\" 
data-align=\"center\"><strong>Tensor parallelism<\/strong><\/td><td class=\"has-text-align-center\" data-align=\"center\"><strong>Fully sharded data parallelism<\/strong><\/td><td class=\"has-text-align-center\" data-align=\"center\"><strong>Sequence length per GPU after CP split<\/strong><\/td><\/tr><tr><td class=\"has-text-align-center\" data-align=\"center\">64K<\/td><td class=\"has-text-align-center\" data-align=\"center\">1-4<\/td><td class=\"has-text-align-center\" data-align=\"center\">4-16<\/td><td class=\"has-text-align-center\" data-align=\"center\">4-16<\/td><td class=\"has-text-align-center\" data-align=\"center\">1<\/td><td class=\"has-text-align-center\" data-align=\"center\">1-2<\/td><td class=\"has-text-align-center\" data-align=\"center\">4K-16K<\/td><\/tr><tr><td class=\"has-text-align-center\" data-align=\"center\">128K<\/td><td class=\"has-text-align-center\" data-align=\"center\">2-8<\/td><td class=\"has-text-align-center\" data-align=\"center\">8-32<\/td><td class=\"has-text-align-center\" data-align=\"center\">8-32<\/td><td class=\"has-text-align-center\" data-align=\"center\">1<\/td><td class=\"has-text-align-center\" data-align=\"center\">1-2<\/td><td class=\"has-text-align-center\" data-align=\"center\">4K-16K<\/td><\/tr><tr><td class=\"has-text-align-center\" data-align=\"center\">256K<\/td><td class=\"has-text-align-center\" data-align=\"center\">8-16<\/td><td class=\"has-text-align-center\" data-align=\"center\">32-64<\/td><td class=\"has-text-align-center\" data-align=\"center\">16-32<\/td><td class=\"has-text-align-center\" data-align=\"center\">2<\/td><td class=\"has-text-align-center\" data-align=\"center\">1-2<\/td><td class=\"has-text-align-center\" data-align=\"center\">8K-16K<\/td><\/tr><\/tbody><\/table><figcaption class=\"wp-element-caption\"><em>Table 1. 
Parallelism configurations tested across different sequence lengths<\/em><\/figcaption><\/figure>\n\n\n\n<p>Longer sequences (256K) employed tensor parallelism (TP=2) in addition to context parallelism to fit the model within GPU memory constraints.<\/p>\n\n\n\n<h3 id=\"communication_backend_comparison\"  class=\"wp-block-heading\">Communication backend comparison<a href=\"#communication_backend_comparison\" class=\"heading-anchor-link\"><i class=\"fas fa-link\"><\/i><\/a><\/h3>\n\n\n\n<p>Each configuration was evaluated with two communication backends:<\/p>\n\n\n\n<ol class=\"wp-block-list\">\n<li>NCCL (baseline)<\/li>\n\n\n\n<li>NVSHMEM-enabled implementation<\/li>\n<\/ol>\n\n\n\n<p>Measurements:<\/p>\n\n\n\n<ul class=\"wp-block-list\">\n<li><strong>TFLOP\/s per device<\/strong>: GPU computational throughput<\/li>\n\n\n\n<li><strong>Step time (seconds)<\/strong>: Time per training iteration<\/li>\n\n\n\n<li><strong>Speedup<\/strong>: Relative performance improvement of NVSHMEM over NCCL<\/li>\n<\/ul>\n\n\n\n<p>All metrics were averaged across iterations 3-20 (skipping the first two warmup iterations) and computed from rank 0 logs to ensure consistency.<\/p>\n\n\n\n<h2 id=\"performance_results\"  class=\"wp-block-heading\">Performance results<a href=\"#performance_results\" class=\"heading-anchor-link\"><i class=\"fas fa-link\"><\/i><\/a><\/h2>\n\n\n\n<p>As shown in Table 2, the NVSHMEM performance advantage grows significantly with sequence length:<\/p>\n\n\n\n<ul class=\"wp-block-list\">\n<li><strong>64K sequences<\/strong>: 0.3-3.9% speedup (modest improvement)<\/li>\n\n\n\n<li><strong>128K sequences<\/strong>: -0.2% to +2.4% speedup (neutral to modest improvement)<\/li>\n\n\n\n<li><strong>256K sequences<\/strong>: 30.4-36.3% speedup (dramatic improvement)<\/li>\n<\/ul>\n\n\n\n<p>This scaling behavior aligns with the ring attention communication pattern: longer sequences require more KV tensor exchanges around the ring, amplifying the benefits of NVSHMEM&#8217;s lower-latency 
communication.<\/p>\n\n\n\n<p>When scaling across nodes, internode communication latency becomes more critical. NVSHMEM nonblocking host APIs and optimized data paths provide consistent benefits across 8-16 node deployments.<\/p>\n\n\n\n<figure class=\"wp-block-table aligncenter\"><table class=\"has-fixed-layout\"><tbody><tr><td class=\"has-text-align-center\" data-align=\"center\"><strong>Sequence length<\/strong><\/td><td class=\"has-text-align-center\" data-align=\"center\"><strong>Nodes<\/strong><\/td><td class=\"has-text-align-center\" data-align=\"center\"><strong>CP<\/strong><\/td><td class=\"has-text-align-center\" data-align=\"center\"><strong>TP<\/strong><\/td><td class=\"has-text-align-center\" data-align=\"center\"><strong>GPUs<\/strong><\/td><td class=\"has-text-align-center\" data-align=\"center\"><strong>Seq\/GPU<\/strong><\/td><td class=\"has-text-align-center\" data-align=\"center\"><strong>Default TFLOP\/s<\/strong><\/td><td class=\"has-text-align-center\" data-align=\"center\"><strong>NVSHMEM TFLOP\/s<\/strong><\/td><td class=\"has-text-align-center\" data-align=\"center\"><strong>Speedup<\/strong><\/td><\/tr><tr><td class=\"has-text-align-center\" data-align=\"center\"><strong>64K<\/strong><\/td><td class=\"has-text-align-center\" data-align=\"center\">1<\/td><td class=\"has-text-align-center\" data-align=\"center\">4<\/td><td class=\"has-text-align-center\" data-align=\"center\">1<\/td><td class=\"has-text-align-center\" data-align=\"center\">4<\/td><td class=\"has-text-align-center\" data-align=\"center\">16K<\/td><td class=\"has-text-align-center\" data-align=\"center\">605.64<\/td><td class=\"has-text-align-center\" data-align=\"center\">607.36<\/td><td class=\"has-text-align-center\" data-align=\"center\">+0.3%<\/td><\/tr><tr><td class=\"has-text-align-center\" data-align=\"center\"><strong>64K<\/strong><\/td><td class=\"has-text-align-center\" data-align=\"center\">2<\/td><td class=\"has-text-align-center\" data-align=\"center\">8<\/td><td 
class=\"has-text-align-center\" data-align=\"center\">1<\/td><td class=\"has-text-align-center\" data-align=\"center\">8<\/td><td class=\"has-text-align-center\" data-align=\"center\">8K<\/td><td class=\"has-text-align-center\" data-align=\"center\">549.92<\/td><td class=\"has-text-align-center\" data-align=\"center\">557.17<\/td><td class=\"has-text-align-center\" data-align=\"center\">+1.3%<\/td><\/tr><tr><td class=\"has-text-align-center\" data-align=\"center\"><strong>64K<\/strong><\/td><td class=\"has-text-align-center\" data-align=\"center\">4<\/td><td class=\"has-text-align-center\" data-align=\"center\">16<\/td><td class=\"has-text-align-center\" data-align=\"center\">1<\/td><td class=\"has-text-align-center\" data-align=\"center\">16<\/td><td class=\"has-text-align-center\" data-align=\"center\">4K<\/td><td class=\"has-text-align-center\" data-align=\"center\">482.19<\/td><td class=\"has-text-align-center\" data-align=\"center\">501.06<\/td><td class=\"has-text-align-center\" data-align=\"center\">+3.9%<\/td><\/tr><tr><td class=\"has-text-align-center\" data-align=\"center\"><strong>128K<\/strong><\/td><td class=\"has-text-align-center\" data-align=\"center\">2<\/td><td class=\"has-text-align-center\" data-align=\"center\">8<\/td><td class=\"has-text-align-center\" data-align=\"center\">1<\/td><td class=\"has-text-align-center\" data-align=\"center\">8<\/td><td class=\"has-text-align-center\" data-align=\"center\">16K<\/td><td class=\"has-text-align-center\" data-align=\"center\">512.22<\/td><td class=\"has-text-align-center\" data-align=\"center\">515.87<\/td><td class=\"has-text-align-center\" data-align=\"center\">+0.7%<\/td><\/tr><tr><td class=\"has-text-align-center\" data-align=\"center\"><strong>128K<\/strong><\/td><td class=\"has-text-align-center\" data-align=\"center\">4<\/td><td class=\"has-text-align-center\" data-align=\"center\">16<\/td><td class=\"has-text-align-center\" data-align=\"center\">1<\/td><td class=\"has-text-align-center\" 
data-align=\"center\">16<\/td><td class=\"has-text-align-center\" data-align=\"center\">8K<\/td><td class=\"has-text-align-center\" data-align=\"center\">473.58<\/td><td class=\"has-text-align-center\" data-align=\"center\">472.46<\/td><td class=\"has-text-align-center\" data-align=\"center\">-0.2%<\/td><\/tr><tr><td class=\"has-text-align-center\" data-align=\"center\"><strong>128K<\/strong><\/td><td class=\"has-text-align-center\" data-align=\"center\">8<\/td><td class=\"has-text-align-center\" data-align=\"center\">32<\/td><td class=\"has-text-align-center\" data-align=\"center\">1<\/td><td class=\"has-text-align-center\" data-align=\"center\">32<\/td><td class=\"has-text-align-center\" data-align=\"center\">4K<\/td><td class=\"has-text-align-center\" data-align=\"center\">420.99<\/td><td class=\"has-text-align-center\" data-align=\"center\">431.13<\/td><td class=\"has-text-align-center\" data-align=\"center\">+2.4%<\/td><\/tr><tr><td class=\"has-text-align-center\" data-align=\"center\"><strong>256K<\/strong><\/td><td class=\"has-text-align-center\" data-align=\"center\">8<\/td><td class=\"has-text-align-center\" data-align=\"center\">16<\/td><td class=\"has-text-align-center\" data-align=\"center\">2<\/td><td class=\"has-text-align-center\" data-align=\"center\">32<\/td><td class=\"has-text-align-center\" data-align=\"center\">16K<\/td><td class=\"has-text-align-center\" data-align=\"center\">366.94<\/td><td class=\"has-text-align-center\" data-align=\"center\">500.22<\/td><td class=\"has-text-align-center\" data-align=\"center\">+36.3%<\/td><\/tr><tr><td class=\"has-text-align-center\" data-align=\"center\"><strong>256K<\/strong><\/td><td class=\"has-text-align-center\" data-align=\"center\">16<\/td><td class=\"has-text-align-center\" data-align=\"center\">32<\/td><td class=\"has-text-align-center\" data-align=\"center\">2<\/td><td class=\"has-text-align-center\" data-align=\"center\">64<\/td><td class=\"has-text-align-center\" 
data-align=\"center\">8K<\/td><td class=\"has-text-align-center\" data-align=\"center\">346.33<\/td><td class=\"has-text-align-center\" data-align=\"center\">451.70<\/td><td class=\"has-text-align-center\" data-align=\"center\">+30.4%<\/td><\/tr><\/tbody><\/table><figcaption class=\"wp-element-caption\"><em>Table 2. Performance comparison of default (NCCL) and NVSHMEM across different configurations<\/em><\/figcaption><\/figure>\n\n\n\n<h3 id=\"practical_implications\"  class=\"wp-block-heading\">Practical implications<a href=\"#practical_implications\" class=\"heading-anchor-link\"><i class=\"fas fa-link\"><\/i><\/a><\/h3>\n\n\n\n<p>Based on these results, NVSHMEM provides clear advantages for:<\/p>\n\n\n\n<ul class=\"wp-block-list\">\n<li><strong>Long-context training<\/strong>: Sequences \u2265 128K tokens, where communication becomes a bottleneck<\/li>\n\n\n\n<li><strong>Multinode deployments<\/strong>: Scaling beyond single-node NVLink domains<\/li>\n\n\n\n<li><strong>Ring attention and similar patterns<\/strong>: Workloads with fine-grained, latency-sensitive communication<\/li>\n\n\n\n<li><strong>Hybrid parallelism<\/strong>: Configurations combining CP, TP, and FSDP<\/li>\n<\/ul>\n\n\n\n<p>The XLA integration makes NVSHMEM accessible to JAX. No user code changes are required; simply use an NVSHMEM-enabled XLA build and set the appropriate environment flags.<\/p>\n\n\n\n<h2 id=\"get_started_with_long-context_model_training&nbsp;\"  class=\"wp-block-heading\">Get started with long-context model training&nbsp;<a href=\"#get_started_with_long-context_model_training&nbsp;\" class=\"heading-anchor-link\"><i class=\"fas fa-link\"><\/i><\/a><\/h2>\n\n\n\n<p>Training LLMs with long-context windows requires efficient communication strategies that can handle fine-grained, latency-sensitive data exchanges. 
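<\/p>\n\n\n\n<p>In practice, turning the feature on is a one-line change. A minimal sketch using the runtime flag described earlier (set it before JAX initializes its XLA backend):<\/p>

```python
import os

# Append the experimental NVSHMEM flag (default: false) to any existing
# XLA flags. Do this before importing jax, since XLA reads XLA_FLAGS
# when the backend is first initialized.
os.environ['XLA_FLAGS'] = (
    os.environ.get('XLA_FLAGS', '')
    + ' --xla_gpu_experimental_enable_nvshmem=true'
).strip()

# import jax  # import only after XLA_FLAGS is set
print(os.environ['XLA_FLAGS'])
```

<p>From there, training scripts run unchanged and the compiler decides per collective operation whether NVSHMEM is used.<\/p>\n\n\n\n<p>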
The integration of NVSHMEM into XLA enables transparent acceleration of context parallelism with ring attention, providing up to 36% speedup for 256K-token sequences on Llama 3 8B.<\/p>\n\n\n\n<p>Key takeaways:<\/p>\n\n\n\n<ul class=\"wp-block-list\">\n<li>NVSHMEM&#8217;s nonblocking host APIs and low-latency data paths are ideally suited to the ring attention communication pattern<\/li>\n\n\n\n<li>XLA compiler integration makes NVSHMEM accessible to high-level frameworks without requiring code changes<\/li>\n\n\n\n<li>Performance benefits scale with sequence length, with dramatic improvements for sequences \u2265 256K tokens<\/li>\n\n\n\n<li>Multinode deployments see the largest gains, making NVSHMEM essential for production long-context training<\/li>\n<\/ul>\n\n\n\n<p>As context windows continue to grow, low-latency communication solutions like NVSHMEM will be crucial for making long-context training practical and cost-effective. We encourage the community to try NVSHMEM-enabled XLA builds in the JAX framework and share their experiences with long-context workloads.<\/p>\n\n\n\n<p>To get started, check out <a href=\"https:\/\/github.com\/google\/maxtext\">MaxText<\/a>, <a href=\"https:\/\/github.com\/nvidia\/JAX-Toolbox\/pkgs\/container\/jax\">NVIDIA\/JAX-Toolbox<\/a>, and <a href=\"https:\/\/github.com\/openxla\/xla\">openxla\/xla<\/a> on GitHub.<\/p>\n\n\n\n<h3 id=\"acknowledgments\"  class=\"wp-block-heading\">Acknowledgments<a href=\"#acknowledgments\" class=\"heading-anchor-link\"><i class=\"fas fa-link\"><\/i><\/a><\/h3>\n\n\n\n<p><em>We would like to express our gratitude to NVSHMEM contributors Seth Howell and Akhil Langer.<\/em><\/p>\n","protected":false},"excerpt":{"rendered":"<p>Large language models (LLMs) are rapidly expanding their context windows, with recent models supporting sequences of 128K tokens, 256K tokens, and beyond. 
However, training these models with extended context lengths presents significant computational and communication challenges. As context lengths grow, the memory and communication overhead of attention mechanisms scale quadratically, creating bottlenecks that traditional parallelism &hellip; <a href=\"https:\/\/developer.nvidia.com\/blog\/accelerating-long-context-model-training-in-jax-and-xla\/\">Continued<\/a><\/p>\n","protected":false},"author":3167,"featured_media":112286,"comment_status":"closed","ping_status":"closed","sticky":false,"template":"","format":"standard","meta":{"_acf_changed":false,"publish_to_discourse":"","publish_post_category":"318","wpdc_auto_publish_overridden":"1","wpdc_topic_tags":"","wpdc_pin_topic":"","wpdc_pin_until":"","discourse_post_id":"1754731","discourse_permalink":"https:\/\/forums.developer.nvidia.com\/t\/accelerating-long-context-model-training-in-jax-and-xla\/359562","wpdc_publishing_response":"success","wpdc_publishing_error":"","nv_subtitle":"","ai_post_summary":"<ul><li>Integrating NVSHMEM with the XLA compiler and JAX enables efficient training of Llama 3 8B on sequences up to 256K tokens, yielding up to 36% speedup over NVIDIA NCCL for long-context workloads, especially when combined with tensor parallelism across multiple nodes.<\/li><li>NVSHMEM provides GPU-optimized features such as symmetric memory, stream-aware communication, and CUDA Graphs interoperability, which are leveraged by XLA&#039;s backend selection heuristics to optimize ring attention&#039;s fine-grained, latency-sensitive communication patterns in context parallelism.<\/li><li>Performance gains from NVSHMEM scale with sequence length and are most pronounced in multinode deployments and hybrid parallelism configurations, making NVSHMEM essential for production long-context LLM training using JAX and XLA without requiring user code 
modifications.<\/li><\/ul>","footnotes":"","_links_to":"","_links_to_target":""},"categories":[3110,4146,1205],"tags":[2499,453,3650,3596],"coauthors":[4996,4700,4997,4699,4008],"class_list":["post-112140","post","type-post","status-publish","format-standard","has-post-thumbnail","hentry","category-generative-ai","category-development","category-networking-communications","tag-cuda-graphs","tag-featured","tag-llm-techniques","tag-training-ai-models","tagify_workload-generative-ai","tagify_workload-data-center-cloud","tagify_workload-data-science","tagify_workload-networking-communications"],"acf":{"post_industry":["General"],"post_products":["NCCL","NVLink","NVSHMEM","NVSwitch"],"post_learning_levels":["Intermediate Technical"],"post_content_types":["Tutorial"],"post_collections":""},"jetpack_featured_media_url":"https:\/\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/01\/llm-cloud-icons.png","primary_category":{"category":"Agentic AI \/ Generative AI","link":"https:\/\/developer.nvidia.com\/blog\/category\/generative-ai\/","id":3110,"data_source":""},"nv_translations":[],"jetpack_shortlink":"https:\/\/wp.me\/pcCQAL-taI","jetpack_likes_enabled":true,"jetpack_sharing_enabled":true,"_links":{"self":[{"href":"https:\/\/developer-blogs.nvidia.com\/wp-json\/wp\/v2\/posts\/112140","targetHints":{"allow":["GET"]}}],"collection":[{"href":"https:\/\/developer-blogs.nvidia.com\/wp-json\/wp\/v2\/posts"}],"about":[{"href":"https:\/\/developer-blogs.nvidia.com\/wp-json\/wp\/v2\/types\/post"}],"author":[{"embeddable":true,"href":"https:\/\/developer-blogs.nvidia.com\/wp-json\/wp\/v2\/users\/3167"}],"replies":[{"embeddable":true,"href":"https:\/\/developer-blogs.nvidia.com\/wp-json\/wp\/v2\/comments?post=112140"}],"version-history":[{"count":18,"href":"https:\/\/developer-blogs.nvidia.com\/wp-json\/wp\/v2\/posts\/112140\/revisions"}],"predecessor-version":[{"id":112342,"href":"https:\/\/developer-blogs.nvidia.com\/wp-json\/wp\/v2\/posts\/112140\/revisions\/112342"}],"wp:f
eaturedmedia":[{"embeddable":true,"href":"https:\/\/developer-blogs.nvidia.com\/wp-json\/wp\/v2\/media\/112286"}],"wp:attachment":[{"href":"https:\/\/developer-blogs.nvidia.com\/wp-json\/wp\/v2\/media?parent=112140"}],"wp:term":[{"taxonomy":"category","embeddable":true,"href":"https:\/\/developer-blogs.nvidia.com\/wp-json\/wp\/v2\/categories?post=112140"},{"taxonomy":"post_tag","embeddable":true,"href":"https:\/\/developer-blogs.nvidia.com\/wp-json\/wp\/v2\/tags?post=112140"},{"taxonomy":"author","embeddable":true,"href":"https:\/\/developer-blogs.nvidia.com\/wp-json\/wp\/v2\/coauthors?post=112140"}],"curies":[{"name":"wp","href":"https:\/\/api.w.org\/{rel}","templated":true}]}}