<?xml version="1.0" encoding="utf-8"?><feed xmlns="http://www.w3.org/2005/Atom" ><generator uri="https://jekyllrb.com/" version="3.10.0">Jekyll</generator><link href="https://danielhuangjiakang.github.io/feed.xml" rel="self" type="application/atom+xml" /><link href="https://danielhuangjiakang.github.io/" rel="alternate" type="text/html" /><updated>2026-04-01T07:18:28+00:00</updated><id>https://danielhuangjiakang.github.io/feed.xml</id><title type="html">Jiakang Huang’s Website</title><subtitle>Jiakang Huang&apos;s personal academic and research website</subtitle><author><name>Jiakang Huang</name><email>jhuang74@student.ubc.ca</email></author><entry xml:lang="zh"><title type="html">PyTorch Inductor 中 speedup_by_fusion 深度解析</title><link href="https://danielhuangjiakang.github.io/zh/blog/speedup-by-fusion-pytorch-inductor/" rel="alternate" type="text/html" title="PyTorch Inductor 中 speedup_by_fusion 深度解析" /><published>2026-03-31T00:00:00+00:00</published><updated>2026-03-31T00:00:00+00:00</updated><id>https://danielhuangjiakang.github.io/zh/blog/speedup-by-fusion-pytorch-inductor-CN</id><content type="html" xml:base="https://danielhuangjiakang.github.io/zh/blog/speedup-by-fusion-pytorch-inductor/"><![CDATA[<p><strong>作者：</strong> Jiakang Huang</p>

<figure class="post-feature-image">
  <img src="/images/speedup-by-fusion-cover.png" alt="PyTorch Inductor 中 speedup_by_fusion 的封面图" />
</figure>

<p>上篇文章我们分析了 Inductor 中 <code class="language-plaintext highlighter-rouge">fuse_nodes</code> 的整体架构和工作流程（详见：<a href="/zh/blog/fuse-nodes-pytorch-inductor/">PyTorch Inductor 中 fuse_nodes 融合流程深度解析</a>）。本篇我们将聚焦其中一个有趣的配置项 <code class="language-plaintext highlighter-rouge">speedup_by_fusion</code>，从开启方式、运行机制、实际日志到局限性展开讨论。</p>

<h2 id="1-如何开启-speedup_by_fusion">1. 如何开启 speedup_by_fusion</h2>

<p><code class="language-plaintext highlighter-rouge">speedup_by_fusion</code> 是 <code class="language-plaintext highlighter-rouge">torch._inductor.config</code> 中的一个配置项。开启后，Inductor 在融合决策阶段会通过实际 benchmark 来判断两个算子融合后是否真的更快，而不仅仅依赖启发式打分。</p>

<p>可以通过以下方式开启：</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">torch._inductor.config</span> <span class="k">as</span> <span class="n">config</span>
<span class="n">config</span><span class="p">.</span><span class="n">benchmark_fusion</span> <span class="o">=</span> <span class="bp">True</span>
</code></pre></div></div>

<p>或者通过 <code class="language-plaintext highlighter-rouge">torch.compile</code> 的 <code class="language-plaintext highlighter-rouge">options</code> 参数传入：</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">compiled_model</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="nb">compile</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">backend</span><span class="o">=</span><span class="s">"inductor"</span><span class="p">,</span> <span class="n">options</span><span class="o">=</span><span class="p">{</span><span class="s">"benchmark_fusion"</span><span class="p">:</span> <span class="bp">True</span><span class="p">})</span>
</code></pre></div></div>

<h2 id="2-开启后做了什么">2. 开启后做了什么</h2>

<p>在默认模式下，Inductor 的融合决策完全基于启发式规则——通过 <code class="language-plaintext highlighter-rouge">can_fuse</code> 检查合法性，通过 <code class="language-plaintext highlighter-rouge">score_fusion</code> 打分排序，然后贪心地执行融合。</p>

<p>开启 <code class="language-plaintext highlighter-rouge">benchmark_fusion</code> 后，流程增加了一个关键步骤：<strong>对候选融合对进行实际 GPU benchmark</strong>。具体来说，系统会分别计时：</p>

<ul>
  <li>两个算子<strong>独立运行</strong>的总耗时</li>
  <li>两个算子<strong>融合后</strong>作为一个 kernel 的耗时</li>
</ul>

<p>只有当融合后确实更快时，才执行该融合。</p>
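
<p>这个决策规则可以用几行示意代码表达（仅为说明性示例，并非 Inductor 的真实实现；真实逻辑位于 scheduler 的 benchmark 路径中）：</p>

```python
# 示意代码：只有当融合 kernel 的耗时低于两个 kernel 串行耗时之和
# （其中也隐含省掉了一次 launch overhead）时才接受融合。
def should_fuse(ms_node1, ms_node2, ms_fused):
    return ms_fused < ms_node1 + ms_node2

def speedup(ms_node1, ms_node2, ms_fused):
    # 大致对应日志中 "2.462x speedup" 这样的比值
    return (ms_node1 + ms_node2) / ms_fused

print(should_fuse(0.30, 0.25, 0.40))        # True
print(round(speedup(0.30, 0.25, 0.40), 3))  # 1.375
```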

<h2 id="3-日志中的-speedup-示例">3. 日志中的 Speedup 示例</h2>

<p>开启后，在 Inductor 的 fusion 日志中可以看到类似如下的输出：</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>V0312 02:40:20.816000 3795204 scheduler.py:4396] [0/0] [__fusion]
  can fuse (benchmark): fusing OrderedSet(['buf17']) with OrderedSet(['buf18'])
  cause 2.462x speedup
</code></pre></div></div>

<p>这条日志表明 <code class="language-plaintext highlighter-rouge">buf17</code> 和 <code class="language-plaintext highlighter-rouge">buf18</code> 经过实际 benchmark 测试后，融合带来了 <strong>2.462 倍</strong>的加速，因此决定执行融合。</p>

<h2 id="4-局限性与-register-spilling-问题">4. 局限性与 Register Spilling 问题</h2>

<p>开启 <code class="language-plaintext highlighter-rouge">speedup_by_fusion</code> 虽然看起来更加“科学”，但实际使用中存在两个值得讨论的问题。</p>

<h3 id="41-贪心融合的全局最优性问题">4.1 贪心融合的全局最优性问题</h3>

<p>benchmark 测试的是<strong>两个算子</strong>融合前后的性能对比。但这个局部最优并不一定意味着全图在 GPU 上运行时也是最优的。贪心算法的固有缺陷在于：局部最优决策的累积不一定导向全局最优。</p>

<h3 id="42-register-spilling-导致的融合拒绝">4.2 Register Spilling 导致的融合拒绝</h3>

<p>在实际 benchmark 过程中，可能出现融合后的 kernel 因为 <strong>register spilling</strong> 而被拒绝融合的情况。日志示例如下：</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>V0312 02:40:31.500000 3795204 scheduler.py:1776] [0/0] [__fusion]
  cannot fuse op1_op6_op11_op2_op7_op12 with op16_op17_op18:
  register spilling of the fused kernel
</code></pre></div></div>

<p><strong>什么是 Register Spilling？</strong> GPU 的每个线程有有限数量的寄存器。当一个 kernel 需要的寄存器数量超出硬件限制时，多余的变量会被“溢出”到较慢的 local memory 中。这就是 register spilling。它会导致显著的性能下降，因为 local memory 的访问延迟远高于寄存器访问。</p>

<p>当前实现中，一旦检测到 register spilling，就<strong>直接拒绝该融合</strong>，不再进一步评估。这带来了一个重要疑问：</p>

<blockquote>
  <p><strong>即使发生了 register spilling，融合带来的 launch overhead 减少是否有可能超过 spilling 的性能损失？</strong></p>
</blockquote>

<p>换句话说，当前的实现可能因为 register spilling 而过于保守地拒绝了一些实际上有益的融合。</p>
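
<p>下面用一段假设性的示意代码对比两种策略（函数名与参数均为虚构，并非 Inductor 的真实 API）：</p>

```python
# 示意代码：当前的 "硬否决" 策略 vs. 继续 benchmark 的替代策略。
def current_policy(n_spills, ms_separate, ms_fused):
    if n_spills > 0:
        return False  # 硬否决：只要检测到 spilling 就直接拒绝融合
    return ms_fused < ms_separate

def benchmark_through(n_spills, ms_separate, ms_fused):
    return ms_fused < ms_separate  # 无论是否 spilling，都让实测数据说话

# 一个虽然 spilling、但整体仍然更快的融合 kernel：
print(current_policy(4, 0.55, 0.48))     # False（当前策略会拒绝）
print(benchmark_through(4, 0.55, 0.48))  # True
```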

<h2 id="5-实验数据">5. 实验数据</h2>

<p>为验证上述假设，我在 <strong>RTX 5090</strong> 上基于一个合成 workload 做了对比实验。实验环境为 PyTorch 2.10.0+cu128。</p>

<ul>
  <li><strong>benchmark_fusion_0</strong>：关闭 <code class="language-plaintext highlighter-rouge">benchmark_fusion</code>（纯启发式）</li>
  <li><strong>benchmark_fusion_1</strong>：开启 <code class="language-plaintext highlighter-rouge">benchmark_fusion</code></li>
</ul>

<h3 id="模型-20hubconflictroundopt">模型 20：HubConflictRoundOpt</h3>

<p>该模型具有共享 hub tensor 和多分支竞争结构，包含多种 reduction 和 transcendental 运算。</p>

<table>
  <thead>
    <tr>
      <th>指标</th>
      <th>关闭 (fusion_0)</th>
      <th>开启 (fusion_1)</th>
      <th>变化</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>编译后运行时间 (ms)</td>
      <td>0.817</td>
      <td>0.964</td>
      <td>+17.9% (变慢)</td>
    </tr>
    <tr>
      <td>Eager 运行时间 (ms)</td>
      <td>79.36</td>
      <td>60.10</td>
      <td>-24.3%</td>
    </tr>
    <tr>
      <td>编译加速比 vs Eager</td>
      <td>97.1x</td>
      <td>62.4x</td>
      <td>-35.8%</td>
    </tr>
    <tr>
      <td>FX 编译耗时 (s)</td>
      <td>7.22</td>
      <td>20.00</td>
      <td>+176.8%</td>
    </tr>
    <tr>
      <td>融合轮数</td>
      <td>3</td>
      <td>2</td>
      <td>-1</td>
    </tr>
    <tr>
      <td>节点缩减数</td>
      <td>67</td>
      <td>62</td>
      <td>-5</td>
    </tr>
    <tr>
      <td>Benchmark 决策次数</td>
      <td>0</td>
      <td>62</td>
      <td>+62</td>
    </tr>
  </tbody>
</table>

<h3 id="数据分析">数据分析</h3>

<p>查看该 workload 的完整日志后可以确认：<strong>所有少融合的节点，都是因为开启 benchmark 后检测到 register spilling 而被拒绝的。</strong></p>

<p>在这个模型上，开启 <code class="language-plaintext highlighter-rouge">benchmark_fusion</code> 后，融合轮数减少、节点缩减数减少，最终编译后运行时间反而<strong>变慢了 17.9%</strong>。这说明在这个 workload 中，<strong>因 register spilling 而少融合节点所带来的额外 launch overhead，很可能比融合后可能出现的 spilling 成本更大。</strong></p>

<p>更值得注意的是，开启 benchmark 后 FX 编译时间从 <strong>7.22s</strong> 增加到 <strong>20.00s</strong>，增幅约 <strong>176.8%</strong>，因为每个候选对都需要实际在 GPU 上跑一遍。</p>

<h2 id="6-思考">6. 思考</h2>

<p>使用真实 benchmark 来决定两个节点是否应该融合，这无疑是一个聪明的做法——它直接用数据说话，避免了启发式规则可能的误判。</p>

<p>但当前对 register spilling 的处理方式过于简单粗暴：<strong>一旦检测到 spilling，直接拒绝融合，不再进行 benchmark 评估。</strong> 即使只看这一个 workload，这种策略也可能过于保守。</p>

<p>个人认为，即使出现了 register spilling，也应该继续运行 benchmark，让实际的运行数据来决定是否融合。毕竟 register spilling 的影响程度取决于溢出量和访问模式，并非所有 spilling 都会导致不可接受的性能下降。</p>

<p>当然，我对 benchmark 的具体实现方式了解有限，也许存在更好的方法来判断融合前后的性能差异。欢迎大家通过邮件与我讨论。</p>]]></content><author><name>Jiakang Huang</name><email>jhuang74@student.ubc.ca</email></author><summary type="html"><![CDATA[详解 PyTorch Inductor 的 speedup_by_fusion 配置：开启方式、工作原理、benchmark 日志示例，以及 register spilling 带来的融合决策争议。]]></summary></entry><entry xml:lang="en"><title type="html">Deep Dive into speedup_by_fusion in PyTorch Inductor</title><link href="https://danielhuangjiakang.github.io/blog/speedup-by-fusion-pytorch-inductor/" rel="alternate" type="text/html" title="Deep Dive into speedup_by_fusion in PyTorch Inductor" /><published>2026-03-31T00:00:00+00:00</published><updated>2026-03-31T00:00:00+00:00</updated><id>https://danielhuangjiakang.github.io/blog/speedup-by-fusion-pytorch-inductor</id><content type="html" xml:base="https://danielhuangjiakang.github.io/blog/speedup-by-fusion-pytorch-inductor/"><![CDATA[<p><strong>Author:</strong> Jiakang Huang</p>

<figure class="post-feature-image">
  <img src="/images/speedup-by-fusion-cover.png" alt="Cover illustration for speedup_by_fusion in PyTorch Inductor" />
</figure>

<p>In the previous post, we walked through the overall architecture of <code class="language-plaintext highlighter-rouge">fuse_nodes</code> in PyTorch Inductor (see: <a href="/blog/fuse-nodes-pytorch-inductor/">Deep Dive into fuse_nodes in PyTorch Inductor</a>). Today we zoom in on a particularly interesting configuration within the fusion pipeline: <code class="language-plaintext highlighter-rouge">speedup_by_fusion</code> (exposed as <code class="language-plaintext highlighter-rouge">benchmark_fusion</code>). We will cover how to enable it, what it does under the hood, what the logs look like, and a critical limitation around register spilling that may lead to suboptimal fusion decisions.</p>

<h2 id="1-enabling-benchmark_fusion">1. Enabling benchmark_fusion</h2>

<p><code class="language-plaintext highlighter-rouge">benchmark_fusion</code> is a config flag in <code class="language-plaintext highlighter-rouge">torch._inductor.config</code>. When turned on, Inductor uses actual GPU benchmarks—rather than heuristics alone—to decide whether fusing two operators is worthwhile.</p>

<p>You can enable it in two ways:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">torch._inductor.config</span> <span class="k">as</span> <span class="n">config</span>
<span class="n">config</span><span class="p">.</span><span class="n">benchmark_fusion</span> <span class="o">=</span> <span class="bp">True</span>
</code></pre></div></div>

<p>Or via the <code class="language-plaintext highlighter-rouge">options</code> dict passed to <code class="language-plaintext highlighter-rouge">torch.compile</code>:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">compiled_model</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="nb">compile</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">backend</span><span class="o">=</span><span class="s">"inductor"</span><span class="p">,</span> <span class="n">options</span><span class="o">=</span><span class="p">{</span><span class="s">"benchmark_fusion"</span><span class="p">:</span> <span class="bp">True</span><span class="p">})</span>
</code></pre></div></div>

<h2 id="2-what-happens-when-it-is-enabled">2. What Happens When It Is Enabled</h2>

<p>In default mode, Inductor’s fusion decisions are purely heuristic: <code class="language-plaintext highlighter-rouge">can_fuse</code> checks legality, <code class="language-plaintext highlighter-rouge">score_fusion</code> ranks candidates, and fusions are applied greedily.</p>

<p>With <code class="language-plaintext highlighter-rouge">benchmark_fusion</code> enabled, an additional step is inserted: <strong>each candidate fusion pair is actually benchmarked on the GPU</strong>. The system times:</p>

<ul>
  <li>The <strong>separate execution</strong> of both operators</li>
  <li>The <strong>fused execution</strong> as a single kernel</li>
</ul>

<p>A fusion is only committed if the fused kernel is measurably faster.</p>
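
<p>The decision rule can be sketched as a simple timing comparison (illustrative code, not Inductor's actual implementation; the real logic lives in the scheduler's benchmark path):</p>

```python
# Illustrative sketch: accept the fusion only if the fused kernel beats
# running both kernels back-to-back (which also saves one launch overhead).
def should_fuse(ms_node1, ms_node2, ms_fused):
    return ms_fused < ms_node1 + ms_node2

def speedup(ms_node1, ms_node2, ms_fused):
    # roughly the ratio reported as "2.462x speedup" in the fusion logs
    return (ms_node1 + ms_node2) / ms_fused

print(should_fuse(0.30, 0.25, 0.40))        # True
print(round(speedup(0.30, 0.25, 0.40), 3))  # 1.375
```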

<h2 id="3-what-the-logs-look-like">3. What the Logs Look Like</h2>

<p>With benchmark fusion enabled, the Inductor fusion log emits entries like:</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>V0312 02:40:20.816000 3795204 scheduler.py:4396] [0/0] [__fusion]
  can fuse (benchmark): fusing OrderedSet(['buf17']) with OrderedSet(['buf18'])
  cause 2.462x speedup
</code></pre></div></div>

<p>This tells us that <code class="language-plaintext highlighter-rouge">buf17</code> and <code class="language-plaintext highlighter-rouge">buf18</code> were actually benchmarked, and the fused kernel ran <strong>2.462x faster</strong>, so the fusion was accepted.</p>

<h2 id="4-limitations-and-the-register-spilling-problem">4. Limitations and the Register Spilling Problem</h2>

<p>While benchmark-driven fusion sounds strictly better than heuristics, there are two issues worth examining.</p>

<h3 id="41-greedy-fusion-is-not-globally-optimal">4.1 Greedy Fusion Is Not Globally Optimal</h3>

<p>The benchmark evaluates a <strong>single pair</strong> of operators in isolation. Even if fusing A and B is locally faster, it does not guarantee that the resulting full graph is globally optimal. This is an inherent limitation of greedy algorithms: a sequence of locally optimal decisions may not compose into a globally optimal solution.</p>

<h3 id="42-register-spilling-causes-premature-rejection">4.2 Register Spilling Causes Premature Rejection</h3>

<p>During benchmarking, the fused kernel may trigger <strong>register spilling</strong>, at which point the current implementation immediately rejects the fusion without measuring the actual performance impact. Here is an example from the logs:</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>V0312 02:40:31.500000 3795204 scheduler.py:1776] [0/0] [__fusion]
  cannot fuse op1_op6_op11_op2_op7_op12 with op16_op17_op18:
  register spilling of the fused kernel
</code></pre></div></div>

<p><strong>What is register spilling?</strong> Each GPU thread has a limited number of registers. When a kernel requires more registers than the hardware provides per thread, the excess variables are “spilled” to local memory, which resides in much slower off-chip storage. This increases memory traffic and can degrade performance significantly.</p>

<p>The current implementation treats register spilling as a hard rejection signal. But this raises an important question:</p>

<blockquote>
  <p><strong>Could the reduction in kernel launch overhead from fusion outweigh the performance cost of register spilling?</strong></p>
</blockquote>

<p>In other words, the current policy may be too conservative, rejecting fusions that would still be net beneficial despite some spilling.</p>
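
<p>To make the contrast concrete, here is a hypothetical sketch of the two policies (the function and parameter names are illustrative, not Inductor APIs):</p>

```python
# Sketch: the current hard-veto policy vs. a benchmark-through alternative.
def current_policy(n_spills, ms_separate, ms_fused):
    if n_spills > 0:
        return False  # hard veto: any spilling rejects the fusion outright
    return ms_fused < ms_separate

def benchmark_through(n_spills, ms_separate, ms_fused):
    return ms_fused < ms_separate  # let the measured timing decide, spills or not

# A fused kernel that spills but is still faster overall:
print(current_policy(4, 0.55, 0.48))     # False (rejected today)
print(benchmark_through(4, 0.55, 0.48))  # True
```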

<h2 id="5-experimental-results">5. Experimental Results</h2>

<p>To investigate, I ran a controlled experiment on an <strong>RTX 5090</strong> with PyTorch 2.10.0+cu128, comparing two settings on one synthetic workload:</p>

<ul>
  <li><strong>benchmark_fusion_0</strong>: benchmark fusion <strong>off</strong> (heuristic-only)</li>
  <li><strong>benchmark_fusion_1</strong>: benchmark fusion <strong>on</strong></li>
</ul>

<h3 id="model-20-hubconflictroundopt">Model 20: HubConflictRoundOpt</h3>

<p>A synthetic model with a shared hub tensor feeding six competing branches, mixing reductions across different axes and transcendental operations (<code class="language-plaintext highlighter-rouge">tanh</code>, <code class="language-plaintext highlighter-rouge">sin*cos</code>, <code class="language-plaintext highlighter-rouge">relu</code>).</p>

<table>
  <thead>
    <tr>
      <th>Metric</th>
      <th>Off (fusion_0)</th>
      <th>On (fusion_1)</th>
      <th>Change</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>Compiled runtime (ms)</td>
      <td>0.817</td>
      <td>0.964</td>
      <td>+17.9% (slower)</td>
    </tr>
    <tr>
      <td>Eager runtime (ms)</td>
      <td>79.36</td>
      <td>60.10</td>
      <td>-24.3%</td>
    </tr>
    <tr>
      <td>Compiled speedup vs Eager</td>
      <td>97.1x</td>
      <td>62.4x</td>
      <td>-35.8%</td>
    </tr>
    <tr>
      <td>FX compile time (s)</td>
      <td>7.22</td>
      <td>20.00</td>
      <td>+176.8%</td>
    </tr>
    <tr>
      <td>Fusion rounds</td>
      <td>3</td>
      <td>2</td>
      <td>-1</td>
    </tr>
    <tr>
      <td>Net node reduction</td>
      <td>67</td>
      <td>62</td>
      <td>-5</td>
    </tr>
    <tr>
      <td>Benchmark decisions</td>
      <td>0</td>
      <td>62</td>
      <td>+62</td>
    </tr>
  </tbody>
</table>

<h3 id="analysis">Analysis</h3>

<p>After reviewing the full logs for this workload, I can confirm that <strong>every fusion rejected in the benchmark-on run was rejected due to register spilling</strong>—not because the benchmark showed a slowdown.</p>

<p>For this model, turning on <code class="language-plaintext highlighter-rouge">benchmark_fusion</code> reduced the number of fusion rounds, reduced net node elimination, and made the compiled runtime <strong>17.9% slower</strong>. That pattern suggests that <strong>the extra launch overhead from keeping more kernels separate (due to spilling-based rejections) outweighed the cost of the spilling that the fused kernels might have incurred.</strong></p>

<p>The compile-time cost is also substantial: FX compile time increased from <strong>7.22s to 20.00s</strong> (<strong>+176.8%</strong>), since each candidate pair has to be compiled and profiled on the GPU.</p>

<h2 id="6-discussion">6. Discussion</h2>

<p>Using real benchmarks to validate fusion decisions is a smart idea—it replaces speculation with measurement. However, the current handling of register spilling is arguably too blunt: <strong>spilling is treated as a hard veto, bypassing the benchmark entirely.</strong></p>

<p>This single workload already suggests that the policy may be overly conservative. A more nuanced approach would be to let the benchmark run even when spilling is detected, and let the actual timing data determine whether the fusion is worthwhile. After all, the severity of register spilling depends heavily on the amount of spilling and memory access patterns—not all spilling leads to unacceptable performance degradation.</p>

<p>I am not fully familiar with the internals of the benchmark implementation, and there may well be better ways to evaluate pre- and post-fusion performance. If you have thoughts or ideas on this topic, I would love to hear from you—feel free to reach out by email.</p>]]></content><author><name>Jiakang Huang</name><email>jhuang74@student.ubc.ca</email></author><summary type="html"><![CDATA[A benchmark-driven analysis of PyTorch Inductor's speedup_by_fusion config, its runtime logs, and why register spilling can reject fusions that still help.]]></summary></entry><entry xml:lang="zh"><title type="html">PyTorch Inductor 中 fuse_nodes 融合流程深度解析</title><link href="https://danielhuangjiakang.github.io/zh/blog/fuse-nodes-pytorch-inductor/" rel="alternate" type="text/html" title="PyTorch Inductor 中 fuse_nodes 融合流程深度解析" /><published>2026-03-29T00:00:00+00:00</published><updated>2026-03-29T00:00:00+00:00</updated><id>https://danielhuangjiakang.github.io/zh/blog/fuse-nodes-pytorch-inductor-CN</id><content type="html" xml:base="https://danielhuangjiakang.github.io/zh/blog/fuse-nodes-pytorch-inductor/"><![CDATA[<p><strong>作者：</strong> Jiakang Huang，Xueyan Zhang</p>

<figure class="post-feature-image">
  <img src="/images/fuse-nodes-pytorch-inductor-cover.png" alt="PyTorch Inductor fuse_nodes 融合流程示意图" />
</figure>

<h2 id="总览">总览</h2>

<p>下图展示了 <code class="language-plaintext highlighter-rouge">fuse_nodes</code> 的完整调用链。整个过程可以概括为一句话：<strong>在节点图上反复寻找可融合的节点对，按优先级打分排序，然后依次尝试真正的融合，直到图不再缩小为止。</strong></p>

<div class="language-text highlighter-rouge"><div class="highlight"><pre class="highlight"><code>fuse_nodes(nodes)
│
└─► fuse_nodes_once()  ×最多10轮，节点数不变或=1时提前停止
    │
    ├─ 1. get_possible_fusions()  ─────────────────── 枚举所有候选融合对
    │   │
    │   ├─ [Loop 1] 按 buffer_name 分组
    │   │   对每个 fusable node，按其读写的 buffer 归入 dict
    │   │
    │   ├─ [Loop 2] 在每个 buffer 组内 check_all_pairs()
    │   │   │  ► 窗口优化：只看前后各64个邻居 → O(64n) 而非 O(n²)
    │   │   │
    │   │   └─► can_fuse(n1, n2)  ──────────────────── 8大类门控检查
    │   │       │  ① 自身判等           ⑤ 顺序/拓扑依赖
    │   │       │  ② 特殊节点拦截       ⑥ 数据类型兼容
    │   │       │  ③ Template快速放行   ⑦ 内存/尺寸约束
    │   │       │  ④ Grouped节点禁入    ⑧ 其他后端限制
    │   │       │
    │   │       └─ 若失败 &amp; node2 是 template/foreach
    │   │          → 反转方向再试 can_fuse(n2, n1)
    │   │            (容器节点可以"吸收"其他节点)
    │   │
    │   ├─ [Loop 3] aggressive_fusion 模式
    │   │   按 node.group 再分一次组，组内再 check_all_pairs()
    │   │
    │   └─► get_possible_fusions_with_highest_priority() ── 去重 &amp; 选优
    │       │
    │       ├─ get_backend(device).get_fusion_pair_priority(n1, n2)
    │       │   后端接口：CPU/CUDA 各自决定融合方式的优先级
    │       │
    │       └─ 同一 pair 可能来自不同分组路径 → 只保留最高优先级的那条
    │
    ├─ 2. score_fusion_key()  ─────────────────────── 对候选对打分排序
    │   │
    │   └─► V.choices.score_fusion()
    │       基于三个维度：
    │         • 融合类型 (template / reduction / ...)
    │         • 预估节省的内存带宽
    │         • 原始图中的拓扑距离（越近越优先）
    │
    └─ 3. _try_fusion_pairs()  ────────────────────── 按排序顺序逐对尝试融合
        排序至关重要：若先融合 (A,B)，则 (B,C) 自动作废
</code></pre></div></div>

<h2 id="阶段一寻找候选对---get_possible_fusions">阶段一：寻找候选对 - <code class="language-plaintext highlighter-rouge">get_possible_fusions</code></h2>

<p>这一步的目标是从整张图中筛出所有“有可能且值得”被融合的节点对。</p>

<h3 id="buffer-分组---融合的前提">Buffer 分组 - 融合的前提</h3>

<p>代码首先遍历所有 fusable node，按照节点读写的 <code class="language-plaintext highlighter-rouge">buffer_name</code> 建立一个分组字典。背后的直觉很简单：如果两个节点不共享任何 buffer，融合它们大概率没有收益，既不能省掉中间 buffer 的分配，也不能减少内存搬运。因此只在同一个 buffer 组内部做配对检查。</p>

<h3 id="窗口优化---控制搜索空间">窗口优化 - 控制搜索空间</h3>

<p>在每个 buffer 组内调用 <code class="language-plaintext highlighter-rouge">check_all_pairs</code> 做两两配对。这里有一个关键优化：PyTorch 默认只在节点列表的<strong>前后各 64 个邻居</strong>之间检查。对于长度为 <code class="language-plaintext highlighter-rouge">n</code> 的节点列表，候选对数量上界是 <code class="language-plaintext highlighter-rouge">64 * n</code>，而非朴素的 <code class="language-plaintext highlighter-rouge">n^2</code>。这让融合搜索在大型图上依然可控。</p>
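
<p>窗口配对的效果可以用如下示意代码说明（仅为说明性示例，真实逻辑在 Scheduler 的相关方法中；这里只向前枚举，每个无序对恰好出现一次，与“前后各 64 个邻居”等价）：</p>

```python
# 窗口配对示意：每个节点只与其后最多 64 个邻居配对，
# 候选对数量上界约为 64 * n，而非 n^2。
WINDOW = 64

def windowed_pairs(nodes):
    for i in range(len(nodes)):
        for n2 in nodes[i + 1 : i + 1 + WINDOW]:
            yield nodes[i], n2

pairs = list(windowed_pairs(list(range(200))))
print(len(pairs))  # 10720，远小于朴素两两配对的 19900
```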

<h3 id="can_fuse---8-大类门控"><code class="language-plaintext highlighter-rouge">can_fuse</code> - 8 大类门控</h3>

<p>每一对候选都必须通过 <code class="language-plaintext highlighter-rouge">can_fuse(node1, node2)</code> 的严格审查。检查项至少包括：</p>

<ol>
  <li><strong>自身判等</strong>：<code class="language-plaintext highlighter-rouge">node1 == node2</code>，直接跳过。</li>
  <li><strong>特殊节点拦截</strong>：<code class="language-plaintext highlighter-rouge">FusedMixOrderReductions</code> 等已融合节点不允许再次融合。</li>
  <li><strong>Template 快速放行</strong>：template 节点有专门的短路判定通道。</li>
  <li><strong>Grouped 节点禁入</strong>：<code class="language-plaintext highlighter-rouge">GroupedSchedulerNode</code> 已被分组调度，不再参与融合。</li>
  <li><strong>顺序依赖检查</strong>：最重要的一项，确保融合不会打破数据流的拓扑顺序。</li>
  <li>以及数据类型兼容、内存和尺寸约束、后端限制等更多细粒度校验。</li>
</ol>

<p>一个有趣的细节：如果 <code class="language-plaintext highlighter-rouge">can_fuse(n1, n2)</code> 判定失败，但 <code class="language-plaintext highlighter-rouge">n2</code> 是 template 或 foreach 节点，代码会<strong>反转方向</strong>再试一次 <code class="language-plaintext highlighter-rouge">can_fuse(n2, n1)</code>。原因在于 template 和 foreach 本质上是“容器节点”，它们可以把别的节点“吸收”进来，所以方向不同，融合语义也不同。</p>
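
<p>这个反转重试的逻辑大致如下（示意代码，辅助函数均为假设，真实检查在 Scheduler 的 <code class="language-plaintext highlighter-rouge">can_fuse</code> 中）：</p>

```python
# 示意代码：先按原方向检查，失败且 n2 是容器节点时反转再试。
def try_fuse_pair(can_fuse, is_container, n1, n2):
    if can_fuse(n1, n2):
        return (n1, n2)
    if is_container(n2) and can_fuse(n2, n1):
        return (n2, n1)  # 容器节点把另一个节点"吸收"进来
    return None

# 玩具规则：只有容器节点可以出现在左侧。
can = lambda a, b: a.startswith("tmpl")
is_tmpl = lambda n: n.startswith("tmpl")
print(try_fuse_pair(can, is_tmpl, "buf3", "tmpl1"))  # ('tmpl1', 'buf3')
```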

<h3 id="激进模式">激进模式</h3>

<p>当 <code class="language-plaintext highlighter-rouge">config.aggressive_fusion</code> 开启时，代码会额外按 <code class="language-plaintext highlighter-rouge">node.group</code> 再做一轮分组。调度器认为同一 group 内的节点属于同一个更大的逻辑单元，值得更积极地尝试融合。</p>

<h2 id="阶段二去重与打分">阶段二：去重与打分</h2>

<h3 id="去重---get_possible_fusions_with_highest_priority">去重 - <code class="language-plaintext highlighter-rouge">get_possible_fusions_with_highest_priority</code></h3>

<p>同一对 <code class="language-plaintext highlighter-rouge">(node1, node2)</code> 可能从不同的分组路径被重复选出，一次来自 buffer 组，一次来自 node group。不同路径意味着不同的融合方式，而我们只需要保留最优的那一种。</p>

<p>去重的核心依据来自后端接口 <code class="language-plaintext highlighter-rouge">get_backend(device).get_fusion_pair_priority(node1, node2)</code>。这是一个动态分派调用，先根据 device 找到对应的后端，例如 CPU 或 CUDA，再调用该后端自己的优先级评估逻辑。基类默认返回 <code class="language-plaintext highlighter-rouge">0</code>，但各后端可以覆写。</p>

<h3 id="打分---score_fusion_key">打分 - <code class="language-plaintext highlighter-rouge">score_fusion_key</code></h3>

<p>去重后的候选对会经过 <code class="language-plaintext highlighter-rouge">V.choices.score_fusion()</code> 打分。打分维度包括：</p>

<ul>
  <li><strong>融合类型</strong>：template 融合、reduction 融合等不同类型权重不同。</li>
  <li><strong>预估节省的内存带宽</strong>：融合后能少搬多少数据，这是最核心的收益指标。</li>
  <li><strong>原始图中的拓扑距离</strong>：距离越近的节点对越优先融合。</li>
</ul>

<p>所有候选对按分数<strong>从高到低排序</strong>，排序结果直接决定融合的先后顺序。</p>
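
<p>打分排序可以用一个排序键来示意（维度与数值均为虚构示例，并非真实打分函数）：</p>

```python
# 示意代码：节省带宽越多越优先，拓扑距离越近越优先。
candidates = [
    (("A", "B"), 4096, 1),  # (节点对, 预估节省字节数, 拓扑距离)
    (("B", "C"), 4096, 5),
    (("C", "D"), 1024, 2),
]
ranked = sorted(candidates, key=lambda c: (-c[1], c[2]))
print([c[0] for c in ranked])  # [('A', 'B'), ('B', 'C'), ('C', 'D')]
```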

<h2 id="阶段三尝试融合---_try_fusion_pairs">阶段三：尝试融合 - <code class="language-plaintext highlighter-rouge">_try_fusion_pairs</code></h2>

<p>这是真正执行融合的地方。<strong>排序至关重要</strong>：候选对按分数从高到低依次尝试，一旦某个节点已被融合，包含该节点的其他候选对自动作废。</p>

<p>举例来说，假设候选列表中有 <code class="language-plaintext highlighter-rouge">(A, B)</code> 和 <code class="language-plaintext highlighter-rouge">(B, C)</code>，且 <code class="language-plaintext highlighter-rouge">(A, B)</code> 分数更高。那么 <code class="language-plaintext highlighter-rouge">(A, B)</code> 会先被融合，之后 <code class="language-plaintext highlighter-rouge">(B, C)</code> 就不再可行，因为 <code class="language-plaintext highlighter-rouge">B</code> 已经消失在融合节点 <code class="language-plaintext highlighter-rouge">AB</code> 中了。</p>
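
<p>这种候选作废机制可以用几行示意代码表达（假设性实现）：</p>

```python
# 贪心融合示意：按分数顺序依次尝试，节点一旦被融合，
# 包含它的后续候选对自动作废。
def greedy_fuse(ranked_pairs):
    fused, applied = set(), []
    for n1, n2 in ranked_pairs:
        if n1 in fused or n2 in fused:
            continue  # 某个节点已被融合，该候选对作废
        applied.append((n1, n2))
        fused.update((n1, n2))
    return applied

print(greedy_fuse([("A", "B"), ("B", "C"), ("C", "D")]))
# [('A', 'B'), ('C', 'D')]
```

<p>真实实现中，融合会产生一个新的融合节点，并在下一轮 <code class="language-plaintext highlighter-rouge">fuse_nodes_once</code> 中继续参与融合；这里为简化而省略。</p>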

<p>这种贪心策略加上前面精心设计的打分函数，使得 Inductor 能在合理的时间内找到一个高质量的融合方案。</p>

<h2 id="小结">小结</h2>

<p><code class="language-plaintext highlighter-rouge">fuse_nodes</code> 的设计体现了几个工程上的权衡：</p>

<ul>
  <li><strong>窗口优化</strong>把搜索复杂度从 <code class="language-plaintext highlighter-rouge">O(n^2)</code> 压到接近 <code class="language-plaintext highlighter-rouge">O(n)</code> 的实践表现，让大型图也可行。</li>
  <li><strong>多路分组</strong>通过 buffer 组、node group 和 aggressive 模式，在不同粒度上捕捉融合机会。</li>
  <li><strong>后端分派</strong>让 CPU 和 CUDA 可以各自定义融合偏好。</li>
  <li><strong>贪心排序</strong>用一个简单但有效的策略，在候选对之间做出取舍。</li>
</ul>

<p>整体来看，这是一个“宽搜索到窄筛选到贪心决策”的经典优化流程。</p>]]></content><author><name>Jiakang Huang</name><email>jhuang74@student.ubc.ca</email></author><summary type="html"><![CDATA[系统梳理 PyTorch Inductor 如何在 fuse_nodes 中枚举候选对、执行打分排序，并以贪心策略推进图融合。]]></summary></entry><entry xml:lang="en"><title type="html">Deep Dive into fuse_nodes in PyTorch Inductor</title><link href="https://danielhuangjiakang.github.io/blog/fuse-nodes-pytorch-inductor/" rel="alternate" type="text/html" title="Deep Dive into fuse_nodes in PyTorch Inductor" /><published>2026-03-29T00:00:00+00:00</published><updated>2026-03-29T00:00:00+00:00</updated><id>https://danielhuangjiakang.github.io/blog/fuse-nodes-pytorch-inductor</id><content type="html" xml:base="https://danielhuangjiakang.github.io/blog/fuse-nodes-pytorch-inductor/"><![CDATA[<p><strong>Authors:</strong> Jiakang Huang, Xueyan Zhang</p>

<figure class="post-feature-image">
  <img src="/images/fuse-nodes-pytorch-inductor-cover.png" alt="Cover illustration for the PyTorch Inductor fuse_nodes workflow" />
</figure>

<h2 id="overview">Overview</h2>

<p>The diagram below shows the full call chain of <code class="language-plaintext highlighter-rouge">fuse_nodes</code>. The entire process boils down to one sentence: <strong>repeatedly find fusable node pairs in the graph, score and rank them by priority, then greedily attempt the actual fusions until the graph stops shrinking.</strong></p>

<div class="language-text highlighter-rouge"><div class="highlight"><pre class="highlight"><code>fuse_nodes(nodes)
│
└─► fuse_nodes_once()  ×up to 10 rounds; early exit if size unchanged or =1
    │
    ├─ 1. get_possible_fusions()  ─────────────────── enumerate candidate pairs
    │   │
    │   ├─ [Loop 1] Group nodes by buffer_name
    │   │   For each fusable node, bucket it by the buffers it reads/writes
    │   │
    │   ├─ [Loop 2] check_all_pairs() within each buffer group
    │   │   │  ► Window optimization: only check ±64 neighbors → O(64n) not O(n²)
    │   │   │
    │   │   └─► can_fuse(n1, n2)  ──────────────────── 8 categories of gate checks
    │   │       │  ① Identity check          ⑤ Topological dependency
    │   │       │  ② Special node block       ⑥ Dtype compatibility
    │   │       │  ③ Template fast-path       ⑦ Memory / size constraints
    │   │       │  ④ Grouped node ban         ⑧ Other backend limits
    │   │       │
    │   │       └─ If failed &amp; node2 is template/foreach
    │   │          → retry reversed: can_fuse(n2, n1)
    │   │            (container nodes can "absorb" other nodes)
    │   │
    │   ├─ [Loop 3] aggressive_fusion mode
    │   │   Re-group by node.group, then check_all_pairs() within each group
    │   │
    │   └─► get_possible_fusions_with_highest_priority() ── deduplicate &amp; select
    │       │
    │       ├─ get_backend(device).get_fusion_pair_priority(n1, n2)
    │       │   Backend interface: CPU/CUDA each decide fusion-method priority
    │       │
    │       └─ Same pair may arrive from different grouping paths
    │          → keep only the highest-priority entry
    │
    ├─ 2. score_fusion_key()  ─────────────────────── score &amp; sort candidates
    │   │
    │   └─► V.choices.score_fusion()
    │       Based on three dimensions:
    │         • Fusion type (template / reduction / ...)
    │         • Estimated memory bandwidth saved
    │         • Topological distance in the original graph (closer = better)
    │
    └─ 3. _try_fusion_pairs()  ────────────────────── attempt fusions in rank order
        Order is critical: fusing (A,B) first invalidates (B,C)
</code></pre></div></div>

<h2 id="phase-1-finding-candidate-pairs---get_possible_fusions">Phase 1: Finding Candidate Pairs - <code class="language-plaintext highlighter-rouge">get_possible_fusions</code></h2>

<p>The goal here is to sift through the entire graph and produce every node pair that is both <em>possible</em> and <em>worthwhile</em> to fuse.</p>

<h3 id="buffer-grouping---the-prerequisite-for-fusion">Buffer Grouping - The Prerequisite for Fusion</h3>

<p>The code first iterates over all fusable nodes and buckets each one by the <code class="language-plaintext highlighter-rouge">buffer_name</code> values it reads or writes, building a grouping dictionary. The intuition is straightforward: if two nodes share no buffers, fusing them is unlikely to yield any benefit because there is no intermediate buffer to eliminate and no memory traffic to save. So pair-checking is restricted to nodes within the same buffer group.</p>

<h3 id="window-optimization---taming-the-search-space">Window Optimization - Taming the Search Space</h3>

<p>Within each buffer group, <code class="language-plaintext highlighter-rouge">check_all_pairs</code> enumerates pairwise candidates. A key optimization keeps this tractable: PyTorch only checks nodes within a <strong>window of plus or minus 64 neighbors</strong> in the node list. For a list of length <code class="language-plaintext highlighter-rouge">n</code>, this caps the number of candidate pairs at <code class="language-plaintext highlighter-rouge">64 * n</code> rather than the naive <code class="language-plaintext highlighter-rouge">n^2</code>. This makes the fusion search feasible even on very large graphs.</p>
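
<p>The effect of the window can be sketched as follows (illustrative only; the real logic lives in the scheduler, and enumerating forward-only yields each unordered pair exactly once, equivalent to the ±64 window):</p>

```python
# Window sketch: each node is paired with at most its next 64 neighbors,
# capping candidates at roughly 64 * n instead of n^2.
WINDOW = 64

def windowed_pairs(nodes):
    for i in range(len(nodes)):
        for n2 in nodes[i + 1 : i + 1 + WINDOW]:
            yield nodes[i], n2

pairs = list(windowed_pairs(list(range(200))))
print(len(pairs))  # 10720, far below the naive 19900 pairs
```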

<h3 id="can_fuse---eight-categories-of-gate-checks"><code class="language-plaintext highlighter-rouge">can_fuse</code> - Eight Categories of Gate Checks</h3>

<p>Every candidate pair must survive the gauntlet of <code class="language-plaintext highlighter-rouge">can_fuse(node1, node2)</code>. The checks include at least:</p>

<ol>
  <li><strong>Identity</strong>: If <code class="language-plaintext highlighter-rouge">node1 == node2</code>, the pair is skipped immediately.</li>
  <li><strong>Special node block</strong>: Nodes like <code class="language-plaintext highlighter-rouge">FusedMixOrderReductions</code> that have already been fused cannot fuse again.</li>
  <li><strong>Template fast-path</strong>: Template nodes have a dedicated short-circuit that can approve fusion quickly.</li>
  <li><strong>Grouped node ban</strong>: <code class="language-plaintext highlighter-rouge">GroupedSchedulerNode</code> instances are already group-scheduled and barred from further fusion.</li>
  <li><strong>Topological dependency</strong>: The most critical check, ensuring fusion will not violate data-flow ordering.</li>
  <li>The remaining categories: <strong>dtype compatibility</strong>, <strong>memory and size constraints</strong>, and <strong>backend-specific limits</strong>, plus other implementation guards.</li>
</ol>

<p>An interesting detail: if <code class="language-plaintext highlighter-rouge">can_fuse(n1, n2)</code> fails but <code class="language-plaintext highlighter-rouge">n2</code> is a <strong>template or foreach node</strong>, the code retries in the <strong>reversed direction</strong> with <code class="language-plaintext highlighter-rouge">can_fuse(n2, n1)</code>. The reason is that template and foreach nodes are effectively container nodes that can absorb other nodes into themselves, so the fusion direction matters.</p>
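<p>The gate sequence and the reversed-direction retry can be modeled with a toy check. Everything below is a hypothetical miniature: the real <code class="language-plaintext highlighter-rouge">can_fuse</code> runs many more checks, and the dict-based nodes are stand-ins.</p>

```python
def can_fuse(n1, n2):
    """Toy gate check: identity, an already-fused block, and a
    direction-sensitive rule where only 'template' nodes may absorb
    their partner."""
    if n1 is n2:
        return False                         # identity gate
    if n1.get("fused") or n2.get("fused"):
        return False                         # already-fused nodes are barred
    return n1["kind"] == "template"          # only containers absorb others

def try_both_directions(n1, n2):
    """If the forward direction fails and n2 is a template/foreach
    container, retry with the operands reversed."""
    if can_fuse(n1, n2):
        return (n1, n2)
    if n2["kind"] in ("template", "foreach") and can_fuse(n2, n1):
        return (n2, n1)
    return None

pointwise = {"kind": "pointwise"}
template = {"kind": "template"}
# can_fuse(pointwise, template) fails, but the reversed direction
# succeeds because the template can absorb the pointwise node.
```

<p>The retry only fires for container-like partners, which matches the intuition in the text: direction matters precisely because one side does the absorbing.</p>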

<h3 id="aggressive-mode">Aggressive Mode</h3>

<p>When <code class="language-plaintext highlighter-rouge">config.aggressive_fusion</code> is enabled, an additional grouping pass runs based on <code class="language-plaintext highlighter-rouge">node.group</code>. The scheduler considers nodes in the same group to be part of a larger logical unit, making them prime candidates for more aggressive fusion attempts.</p>
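<p>Conceptually, aggressive mode just adds a second bucketing pass keyed by <code class="language-plaintext highlighter-rouge">node.group</code>. The sketch below mirrors that idea only; the tuple representation and function name are invented:</p>

```python
from collections import defaultdict

def group_by_node_group(nodes, aggressive_fusion=False):
    """Extra grouping pass that only runs in aggressive mode, bucketing
    nodes by their `group` key. `nodes` is a list of (name, group)
    tuples."""
    groups = defaultdict(list)
    if aggressive_fusion:
        for name, grp in nodes:
            groups[grp].append(name)
    return groups

nodes = [("a", "g1"), ("b", "g1"), ("c", "g2")]
# Off by default: no extra candidates are generated.
assert group_by_node_group(nodes) == {}
# With aggressive_fusion on, "a" and "b" become candidates via "g1".
```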

<h2 id="phase-2-deduplication-and-scoring">Phase 2: Deduplication and Scoring</h2>

<h3 id="deduplication---get_possible_fusions_with_highest_priority">Deduplication - <code class="language-plaintext highlighter-rouge">get_possible_fusions_with_highest_priority</code></h3>

<p>The same pair <code class="language-plaintext highlighter-rouge">(node1, node2)</code> may be discovered through different grouping paths, for example once from a buffer group and once from a node group. Different paths may imply different fusion strategies, but we only want the best one.</p>

<p>The arbiter is the backend interface <code class="language-plaintext highlighter-rouge">get_backend(device).get_fusion_pair_priority(node1, node2)</code>. This is dynamic dispatch: the code first resolves the backend for the current device, such as CPU or CUDA, and then asks that backend to evaluate the pair priority. The base class returns <code class="language-plaintext highlighter-rouge">0</code> by default, but each backend is free to override this.</p>
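<p>A deduplication pass in that spirit might look like the following. The convention that a smaller number means higher priority is an assumption of this sketch, as is the toy backend override:</p>

```python
def dedup_with_priority(discovered, priority_fn):
    """Collapse duplicate (node1, node2) discoveries, keeping the copy
    whose priority is best (smallest, by this sketch's convention)."""
    best = {}
    for pair in discovered:
        key = frozenset(pair)
        if key not in best or priority_fn(*pair) < priority_fn(*best[key]):
            best[key] = pair
    return list(best.values())

# A toy backend override: the template-first orientation wins.
def toy_priority(a, b):
    return 0 if a == "template" else 1

discovered = [("x", "template"), ("template", "x"), ("y", "z")]
deduped = dedup_with_priority(discovered, toy_priority)
```

<p>With the base backend returning <code class="language-plaintext highlighter-rouge">0</code> for every pair, the priority comparison is a tie everywhere and the first-seen copy survives; an overriding backend changes which orientation is kept.</p>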

<h3 id="scoring---score_fusion_key">Scoring - <code class="language-plaintext highlighter-rouge">score_fusion_key</code></h3>

<p>After deduplication, each remaining candidate pair is scored via <code class="language-plaintext highlighter-rouge">V.choices.score_fusion()</code>. The scoring dimensions are:</p>

<ul>
  <li><strong>Fusion type</strong>: Template fusions, reduction fusions, and other categories carry different weights.</li>
  <li><strong>Estimated memory bandwidth saved</strong>: The core payoff metric, measuring how much data movement can be eliminated.</li>
  <li><strong>Topological distance in the original graph</strong>: Closer pairs are preferred.</li>
</ul>

<p>All candidates are sorted from highest to lowest score. That ordering directly determines the sequence in which fusions are attempted.</p>
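<p>The scoring-then-sorting step can be sketched with a toy score over the three dimensions above. The tuple weights are invented for illustration and do not reflect Inductor's actual <code class="language-plaintext highlighter-rouge">score_fusion</code>:</p>

```python
def score_fusion(candidate):
    """Toy score: a fusion-type weight, estimated bytes of memory
    traffic saved, and proximity (closer is better, so distance is
    negated). Tuples compare lexicographically, so type dominates."""
    type_weight, bytes_saved, distance = candidate["features"]
    return (type_weight, bytes_saved, -distance)

candidates = [
    {"name": "AB", "features": (1, 4096, 1)},
    {"name": "BC", "features": (1, 1024, 3)},
    {"name": "CD", "features": (2, 512, 1)},   # e.g. a template fusion
]
# Highest score first: this ordering is the attempt order in phase 3.
ranked = sorted(candidates, key=score_fusion, reverse=True)
```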

<h2 id="phase-3-attempting-fusions---_try_fusion_pairs">Phase 3: Attempting Fusions - <code class="language-plaintext highlighter-rouge">_try_fusion_pairs</code></h2>

<p>This is where fusions actually happen. <strong>The sorted order is paramount</strong>: candidates are tried from highest score to lowest, and once a node has been consumed by a fusion, any other candidate pair involving that node is automatically invalidated.</p>

<p>For example, suppose the candidate list contains <code class="language-plaintext highlighter-rouge">(A, B)</code> and <code class="language-plaintext highlighter-rouge">(B, C)</code>, with <code class="language-plaintext highlighter-rouge">(A, B)</code> scoring higher. <code class="language-plaintext highlighter-rouge">(A, B)</code> will be fused first, after which <code class="language-plaintext highlighter-rouge">(B, C)</code> becomes infeasible because <code class="language-plaintext highlighter-rouge">B</code> has been absorbed into the fused node <code class="language-plaintext highlighter-rouge">AB</code>.</p>
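<p>The greedy loop with invalidation boils down to a few lines. This is a distilled sketch of the strategy, not the real <code class="language-plaintext highlighter-rouge">_try_fusion_pairs</code>:</p>

```python
def greedy_fuse(ranked_pairs):
    """Attempt fusions from best score to worst; once a node has been
    consumed by a fusion, skip every later pair mentioning it."""
    consumed = set()
    applied = []
    for a, b in ranked_pairs:
        if a in consumed or b in consumed:
            continue                  # invalidated by an earlier fusion
        applied.append((a, b))
        consumed.update((a, b))
    return applied

# With (A, B) ranked above (B, C), B is absorbed first and the
# (B, C) candidate is dropped, exactly as in the example above.
result = greedy_fuse([("A", "B"), ("B", "C")])
```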

<p>This greedy strategy, combined with the carefully designed scoring function, allows Inductor to find a high-quality fusion plan in reasonable time.</p>

<h2 id="takeaways">Takeaways</h2>

<p>The design of <code class="language-plaintext highlighter-rouge">fuse_nodes</code> reflects several engineering trade-offs:</p>

<ul>
  <li><strong>Window optimization</strong> reduces search complexity from <code class="language-plaintext highlighter-rouge">O(n^2)</code> to <code class="language-plaintext highlighter-rouge">O(n)</code> behavior in practice, keeping large graphs tractable.</li>
  <li><strong>Multi-path grouping</strong> with buffer groups, node groups, and aggressive mode captures fusion opportunities at different granularities.</li>
  <li><strong>Backend dispatch</strong> lets CPU and CUDA define their own fusion preferences independently.</li>
  <li><strong>Greedy ordering</strong> uses a simple but effective strategy to arbitrate between competing candidate pairs.</li>
</ul>

<p>At a high level, this is a classic optimization pipeline: <strong>broad search to narrow filtering to greedy decision-making</strong>.</p>]]></content><author><name>Jiakang Huang</name><email>jhuang74@student.ubc.ca</email></author><summary type="html"><![CDATA[A structured walkthrough of how PyTorch Inductor enumerates, scores, and greedily applies graph fusion candidates in fuse_nodes.]]></summary></entry></feed>