Introduction

OpenAI's original claim for Triton is:

We’re releasing Triton 1.0, an open-source Python-like programming language which enables researchers with no CUDA experience to write highly efficient GPU code — most of the time on par with what an expert would be able to produce.

The core ideas of Triton:

  • Program GPUs with Python: so that the effort of programming GPU devices is minimal;
  • Encapsulate optimizations within an SM: the user focuses on partitioning the job and scheduling the pieces onto SMs, and Triton is responsible for automatically optimizing the code within each SM;

How did Triton implement it?

Triton turns Python code into Triton IR on the fly by walking the Python AST. It then optimizes the IR and lowers it to LLVM IR / MLIR, generates PTX directly through libLLVM, and finally compiles the PTX to cubin with ptxas.

Did Triton manage to achieve its goal?

  • Program GPUs with Python: quite successful; we can write a functional GPU kernel with Triton very quickly;
  • Encapsulate optimizations within an SM: I would say Triton achieved only about 60% of this goal. Users usually can't write an efficient kernel directly. As a concrete example, the very first version of a GroupNorm kernel we wrote in Triton was functionally correct, but not necessarily efficient, because we didn't know what Triton did to the code;

What’re the advantages of Triton?

  • Easy to use: writing a kernel takes far less time than in CUDA;
  • Performant: it can generate kernels with performance comparable to a skilled CUDA programmer's, provided the user knows how to tune the Triton code;

What’re the disadvantages of Triton?

  • Hard to debug: the whole optimization pipeline is a black box, so the user has to go through the PTX/IR to understand issues;
  • Limitations: some kernels can't be implemented with Triton, e.g. tile sizes must be powers of 2, and slicing isn't supported;

What can we do with Triton?

  • Fast prototyping: we can experiment with various ideas in Triton and iterate very fast, then implement the idea in CUDA once it's nailed down;
  • New kernels/operators: we can use Triton to generate kernels/operators that haven't been implemented in cuBLAS/cuDNN/frameworks;

Compared to manually written CUDA kernels, where do Triton's performance benefits mainly come from?

  • Triton automates a bunch of optimizations that a CUDA developer may not be aware of;
  • Kernels from libraries like cuBLAS must be general-purpose; Triton can drop that requirement on demand and generate simpler code, e.g. bias addition in cuBLAS GEMM;
  • Triton kernels can be auto-tuned;
  • LLVM sometimes generates better PTX than NVCC, e.g. for loop unrolling;

What inspiration can we draw from Triton?

  • Optimizations can be abstracted/encapsulated, or even automated: this can accelerate the CUDA kernel development flow and make life much easier for a lot of developers;
  • Rapid iteration is important: a developer with minimal skills can write a fairly good CUDA kernel after a couple of iterations;

Basic

The goal of this experiment is to build a basic mapping from Triton to CUDA, so that a Triton novice can get a sense of what kind of underlying CUDA code they are manipulating. This matters if we want to use Triton for fast prototyping while keeping performance reasonable, because it's not easy to debug the performance of Triton-generated kernels. Triton's automatic optimization pipeline is complex, so we don't expect the findings to apply to every case. To keep things simple, we only expect them to cover about 80% of the cases.

Setup:

  • NVIDIA 23.03 PyTorch container;
  • Triton d54c04a;
  • GH100-700W;
  • Measured on 04/23/2023;

As a start, we use a simple copy kernel to show basic Triton to CUDA mappings.

import torch

import triton
import triton.language as tl

torch.manual_seed(0)

@triton.jit
def kernel_230423_01(
    x_ptr,
    y_ptr,
    n,
    BLOCK_SIZE_N: tl.constexpr,
):
    idx = tl.program_id(0)
    offsets = idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    mask = offsets < n
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(y_ptr + offsets, x, mask=mask)


def launch_kernel_230423_01():
    n = 1000
    BLOCK_SIZE_N = 128
    x = torch.randn((n, ), device='cuda')
    y = torch.zeros((n, ), device='cuda')
    grid = triton.cdiv(n, BLOCK_SIZE_N)

    kernel_230423_01[(grid, )](x, y, n, BLOCK_SIZE_N)
    assert torch.allclose(x, y)

Reverse engineering the generated PTX, we get the corresponding CUDA code:

__global__ void kernel_230423_01(float *x_ptr, float *y_ptr, int n) {
  const int BLOCK_SIZE_N = 128;
  int lane = threadIdx.x & (BLOCK_SIZE_N - 1);
  int gid = (blockIdx.x << 7) | lane;  // LOG(128) = 7

  if (gid < n) {
    y_ptr[gid] = x_ptr[gid];
  }
}

It’s a very simple example, but we can derive the following conclusions:

  • (grid,) maps to (gridDim.x,);

  • The tensor x is implicitly cast to the pointer x_ptr;

  • BLOCK_SIZE_N, whose type is tl.constexpr, is embedded in the generated code through constant folding and propagation;

  • tl.program_id(0) maps to blockIdx.x;

  • Threads are assigned to tl.arange(0, BLOCK_SIZE_N). In this case it maps directly to threadIdx.x, since the default number of warps is 4 (128 threads) and BLOCK_SIZE_N is 128;

  • mask maps to the if condition, which is finally turned into a predicate;

  • Memory loads and stores on a block of pointers, i.e. tl.load() and tl.store(), are distributed and scheduled across threads. In this example, each thread loads and stores one 32-bit value:

    @%p1 ld.global.b32 { %r2 }, [ %rd1 + 0 ];
    @%p1 st.global.b32 [ %rd2 + 0 ], { %r2 };
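To make the mapping concrete, here is a minimal pure-Python emulation of kernel_230423_01 that runs on the CPU. It is a sketch that assumes one element per thread and 128 threads per CTA, as in the reverse-engineered CUDA; emulate_copy is a hypothetical helper, not part of Triton. It shows two things: the bitwise OR in the index computation is just an addition, and only 104 threads of the last CTA pass the mask for n = 1000.

```python
# Pure-Python emulation of kernel_230423_01 (no GPU needed).
# Assumption: 4 warps = 128 threads per CTA and one element per thread, as in
# the reverse-engineered CUDA above. block_size must be a power of 2.
def emulate_copy(x, block_size=128):
    n = len(x)
    y = [0.0] * n
    log2_block = block_size.bit_length() - 1        # 7 for 128
    grid = (n + block_size - 1) // block_size        # same as triton.cdiv(n, block_size)
    active_per_cta = []
    for block_idx in range(grid):                    # tl.program_id(0) -> blockIdx.x
        active = 0
        for thread_idx in range(block_size):         # tl.arange(0, BLOCK_SIZE_N) -> threadIdx.x
            lane = thread_idx & (block_size - 1)
            gid = (block_idx << log2_block) | lane
            # The OR acts as an addition because lane < 2**log2_block.
            assert gid == block_idx * block_size + lane
            if gid < n:                              # mask -> predicate (@%p1)
                y[gid] = x[gid]
                active += 1
        active_per_cta.append(active)
    return y, active_per_cta


x = [float(i) for i in range(1000)]
y, active = emulate_copy(x)
assert y == x
print(active)  # [128, 128, 128, 128, 128, 128, 128, 104]
```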
    

Scheduling

Re-launch the kernel with fewer warps:

def launch_kernel_230423_02():
    n = 1000
    BLOCK_SIZE_N = 128
    x = torch.randn((n, ), device='cuda')
    y = torch.zeros((n, ), device='cuda')
    grid = triton.cdiv(n, BLOCK_SIZE_N)

    kernel_230423_01[(grid, )](x, y, n, BLOCK_SIZE_N, num_warps=2)
    assert torch.allclose(x, y)

Observations:

  • The compiled kernel kernel_230423_01 is still launched with 128 threads. Passing a different num_warps to an already compiled function neither triggers recompilation nor takes any effect, so num_warps is captured as a constant during the first compilation. This differs from CUDA, where we can change the number of blocks and threads at every kernel launch;

Copy kernel_230423_01 to kernel_230423_03a and kernel_230423_03b. Then launch kernel_230423_03a with 1 warp and kernel_230423_03b with 16 warps to see what happens with fewer or more threads:

@triton.jit
def kernel_230423_03a(  # same as kernel_230423_01
    x_ptr,
    y_ptr,
    n,
    BLOCK_SIZE_N: tl.constexpr,
):
    idx = tl.program_id(0)
    offsets = idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    mask = offsets < n
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(y_ptr + offsets, x, mask=mask)

@triton.jit
def kernel_230423_03b(  # same as kernel_230423_01
    x_ptr,
    y_ptr,
    n,
    BLOCK_SIZE_N: tl.constexpr,
):
    idx = tl.program_id(0)
    offsets = idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    mask = offsets < n
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(y_ptr + offsets, x, mask=mask)

def launch_kernel_230423_03():
    n = 1000
    BLOCK_SIZE_N = 128
    x = torch.randn((n, ), device='cuda')
    y = torch.zeros((n, ), device='cuda')
    grid = triton.cdiv(n, BLOCK_SIZE_N)

    kernel_230423_03a[(grid, )](x, y, n, BLOCK_SIZE_N, num_warps=1)
    assert torch.allclose(x, y)
    y.fill_(0)
    kernel_230423_03b[(grid, )](x, y, n, BLOCK_SIZE_N, num_warps=16)
    assert torch.allclose(x, y)

Reverse engineering the PTX, we get the following CUDA code for them:

__global__ void kernel_230423_03a(int *x_ptr, int *y_ptr, int n) {
  int lane = (threadIdx.x << 2) & 124;
  int gid_base = (blockIdx.x << 7) | lane;

  if (gid_base < n) {
#pragma unroll(4)
    for (int i = 0; i < 4; i++) {
      if (gid_base + i < n) {
        y_ptr[gid_base + i] = x_ptr[gid_base + i];
      }
    }
  }
}

__global__ void kernel_230423_03b(int *x_ptr, int *y_ptr, int n) {
  int lane = threadIdx.x & 127;
  int gid = (blockIdx.x << 7) | lane;

  if (gid < n) {
    y_ptr[gid] = x_ptr[gid];
  }
}

Observations:

  • If the tile size is larger than the thread count, Triton lets each thread process multiple elements. In this specific case, Triton unrolls the code;
  • If the tile size is smaller than the thread count, the redundant threads process the same data and are wasted. The user therefore has to consider how the number of warps relates to the tile size;
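The two observations above can be captured in a small model of how a 1D tile of BLOCK_SIZE_N elements is split over 32 * num_warps threads. This is a hypothetical sketch (thread_mapping and coverage are our own names) that assumes power-of-2 sizes. It reproduces the lane formulas of kernel_230423_03a and kernel_230423_03b and counts how many threads touch each element.

```python
WARP_SIZE = 32


def thread_mapping(block_size, num_warps):
    """Hypothetical model of how a 1D tile is spread over a CTA.

    Assumes block_size and the thread count are powers of 2, and reproduces
    the index patterns of the reverse-engineered kernels above.
    """
    threads = WARP_SIZE * num_warps
    elems_per_thread = max(1, block_size // threads)
    replication = max(1, threads // block_size)
    # Thread tid starts at (tid * elems_per_thread) modulo the tile size:
    # (threadIdx.x << 2) & 124 for 1 warp, threadIdx.x & 127 for 16 warps.
    starts = [(tid * elems_per_thread) & (block_size - 1) for tid in range(threads)]
    return elems_per_thread, replication, starts


def coverage(block_size, num_warps):
    """Number of threads that touch each element of the tile."""
    elems_per_thread, _, starts = thread_mapping(block_size, num_warps)
    counts = [0] * block_size
    for start in starts:
        for k in range(elems_per_thread):
            counts[start + k] += 1
    return counts


for warps in (1, 4, 16):
    ept, rep, _ = thread_mapping(128, warps)
    print(warps, ept, rep, set(coverage(128, warps)))
# 1 4 1 {1}
# 4 1 1 {1}
# 16 1 4 {4}
```

With 16 warps every element is handled by 4 threads. That is exactly the waste described above.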

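However, it's worth noting that Triton uses LLVM, which generates different PTX than NVCC does. In this case the PTX generated by LLVM is clearly better. It is branch-free, with one predicate per element, and it issues all four loads before any store. NVCC instead branches after each element and serializes every load with its store (the result may change with different NVCC versions):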

// Generated by Triton and LLVM                          |  // Generated by NVCC
.visible .entry kernel_03a_0d1d2(                        |  .visible .entry _Z10kernel_03aPiS_i(
    .param .u64 kernel_03a_0d1d2_param_0,                |      .param .u64 _Z10kernel_03aPiS_i_param_0,
    .param .u64 kernel_03a_0d1d2_param_1,                |      .param .u64 _Z10kernel_03aPiS_i_param_1,
    .param .u32 kernel_03a_0d1d2_param_2                 |      .param .u32 _Z10kernel_03aPiS_i_param_2
)                                                        |  )
.maxntid 32, 1, 1                                        |  {
{                                                        |      .reg .pred  %p<5>;
    .reg .pred     %p<9>;                                |      .reg .b32   %r<15>;
    .reg .b32     %r<19>;                                |      .reg .b64   %rd<8>;
    .reg .b64     %rd<12>;                               |
                                                         |      ld.param.u64    %rd3, [_Z10kernel_03aPiS_i_param_0];
    ld.param.u64     %rd9, [kernel_03a_0d1d2_param_0];   |      ld.param.u64    %rd4, [_Z10kernel_03aPiS_i_param_1];
    ld.param.u64     %rd10, [kernel_03a_0d1d2_param_1];  |      ld.param.u32    %r2, [_Z10kernel_03aPiS_i_param_2];
    mov.u32     %r9, %tid.x;                             |      mov.u32     %r3, %tid.x;
    shl.b32     %r10, %r9, 2;                            |      shl.b32     %r4, %r3, 2;
    ld.param.u32     %r11, [kernel_03a_0d1d2_param_2];   |      and.b32     %r5, %r4, 124;
    and.b32      %r12, %r10, 124;                        |      mov.u32     %r6, %ctaid.x;
    mov.u32     %r13, %ctaid.x;                          |      shl.b32     %r7, %r6, 7;
    shl.b32     %r14, %r13, 7;                           |      or.b32      %r1, %r5, %r7;
    or.b32      %r15, %r12, %r14;                        |      setp.ge.s32     %p1, %r1, %r2;
    or.b32      %r16, %r15, 1;                           |      @%p1 bra    $L__BB0_7;
    or.b32      %r17, %r15, 2;                           |
    or.b32      %r18, %r15, 3;                           |      cvta.to.global.u64  %rd5, %rd3;
    setp.lt.s32     %p1, %r15, %r11;                     |      mul.wide.s32    %rd6, %r1, 4;
    setp.lt.s32     %p2, %r16, %r11;                     |      add.s64     %rd1, %rd5, %rd6;
    setp.lt.s32     %p3, %r17, %r11;                     |      ld.global.u32   %r8, [%rd1];
    setp.lt.s32     %p4, %r18, %r11;                     |      cvta.to.global.u64  %rd7, %rd4;
    mul.wide.s32     %rd11, %r15, 4;                     |      add.s64     %rd2, %rd7, %rd6;
    add.s64     %rd1, %rd9, %rd11;                       |      st.global.u32   [%rd2], %r8;
    add.s64     %rd2, %rd1, 4;                           |      add.s32     %r9, %r1, 1;
    add.s64     %rd3, %rd1, 8;                           |      setp.ge.s32     %p2, %r9, %r2;
    add.s64     %rd4, %rd1, 12;                          |      @%p2 bra    $L__BB0_3;
    @%p1 ld.global.b32 { %r5 }, [ %rd1 + 0 ];            |
    @%p2 ld.global.b32 { %r6 }, [ %rd2 + 0 ];            |      ld.global.u32   %r10, [%rd1+4];
    @%p3 ld.global.b32 { %r7 }, [ %rd3 + 0 ];            |      st.global.u32   [%rd2+4], %r10;
    @%p4 ld.global.b32 { %r8 }, [ %rd4 + 0 ];            |
    add.s64     %rd5, %rd10, %rd11;                      |  $L__BB0_3:
    add.s64     %rd6, %rd5, 4;                           |      add.s32     %r11, %r1, 2;
    add.s64     %rd7, %rd5, 8;                           |      setp.ge.s32     %p3, %r11, %r2;
    add.s64     %rd8, %rd5, 12;                          |      @%p3 bra    $L__BB0_5;
    @%p1 st.global.b32 [ %rd5 + 0 ], { %r5 };            |
    @%p2 st.global.b32 [ %rd6 + 0 ], { %r6 };            |      ld.global.u32   %r12, [%rd1+8];
    @%p3 st.global.b32 [ %rd7 + 0 ], { %r7 };            |      st.global.u32   [%rd2+8], %r12;
    @%p4 st.global.b32 [ %rd8 + 0 ], { %r8 };            |
    ret;                                                 |  $L__BB0_5:
                                                         |      add.s32     %r13, %r1, 3;
}                                                        |      setp.ge.s32     %p4, %r13, %r2;
                                                         |      @%p4 bra    $L__BB0_7;
                                                         |
                                                         |      ld.global.u32   %r14, [%rd1+12];
                                                         |      st.global.u32   [%rd2+12], %r14;
                                                         |
                                                         |  $L__BB0_7:
                                                         |      ret;
                                                         |  }

Consider an example where we use a for loop to iterate over the tiles, and launch with more threads than the tile size:

@triton.jit
def kernel_230424_01(
    x_ptr,
    y_ptr,
    N: tl.constexpr,  # static shape
    BLOCK_SIZE_N: tl.constexpr,
):
    tiles = (N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N
    for i in range(tiles):
        offsets = i * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        mask = offsets < N
        x = tl.load(x_ptr + offsets, mask=mask)
        tl.store(y_ptr + offsets, x, mask=mask)

def launch_kernel_230424_01():
    n = 1000
    BLOCK_SIZE_N = 128
    x = torch.randn((n, ), device='cuda')
    y = torch.zeros((n, ), device='cuda')

    kernel_230424_01[(1, )](x, y, n, BLOCK_SIZE_N, num_warps=16)
    assert torch.allclose(x, y)

Reverse engineering the generated PTX code, we get the following CUDA kernel:

__global__ void kernel_230424_01(int *input, int *output) {
  int tid = threadIdx.x;
  int lane = tid & 127;

#pragma unroll(8)
  for (int i = 0; i < 7; ++i) {
    output[lane + i * 128] = input[lane + i * 128];
  }

  if ((lane | (7 * 128)) < 1000) {
    output[lane + 7 * 128] = input[lane + 7 * 128];
  }
}

We launched 1 CTA with 16 warps (512 threads), and the tile size is 128. However, as the equivalent CUDA code above shows, Triton doesn't schedule the threads to parallelize the for loop; it unrolls the loop instead. Because lane = tid & 127, threads 128-511 redundantly repeat the work of threads 0-127, so only 128 threads do useful work. This is counterintuitive and should be handled with care.

What happens if we use a 2D tile?

@triton.jit
def kernel_230424_02(
    x_ptr,
    y_ptr,
    M: tl.constexpr,  # static shape
    N: tl.constexpr,  # static shape
    BLOCK_SIZE_M: tl.constexpr,
):
    rows = tl.arange(0, BLOCK_SIZE_M)
    cols = tl.arange(0, N)
    tiles = (M + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M
    for i in range(tiles):
        # 2D tile, mask
        offsets = i * BLOCK_SIZE_M * N + N * rows[:, None] + cols[None, :]
        mask = (i * BLOCK_SIZE_M + rows < M)[:, None]
        x = tl.load(x_ptr + offsets, mask=mask)
        tl.store(y_ptr + offsets, x, mask=mask)

def launch_kernel_230424_02():
    m, n = 1000, 32
    BLOCK_SIZE_M = 128
    x = torch.randn((m, n), device='cuda')
    y = torch.zeros((m, n), device='cuda')

    kernel_230424_02[(1, )](x, y, m, n, BLOCK_SIZE_M, num_warps=16)
    assert torch.allclose(x, y)

Reverse engineering the PTX, we get the following CUDA code:

__global__ void kernel_230424_02(int *x_ptr, int *y_ptr) {
  int tid_x = threadIdx.x;
  int lane_row = (tid_x >> 3) & 127;
  int lane_col = (tid_x << 2) & 28;
  int offset = (lane_row << 5) | lane_col;

  x_ptr += offset;
  y_ptr += offset;

#pragma unroll(7)
  for (int i = 0; i < 7; ++i) {
    int *ptr1 = x_ptr + i * 4096;
    int *ptr2 = y_ptr + i * 4096;
    uint4 var1 = *((uint4 *)ptr1);
    uint4 var2 = *((uint4 *)(ptr1 + 2048));

    *((uint4 *)ptr2) = var1;
    *((uint4 *)(ptr2 + 2048)) = var2;
  }

  if ((lane_row | 896) < 1000) {
    int *ptr1 = x_ptr + 28672;
    int *ptr2 = y_ptr + 28672;
    uint4 temp = *((uint4 *)ptr1);
    *((uint4 *)ptr2) = temp;
  }

  if (lane_row < 40) {
    int *ptr1 = x_ptr + 30720;
    int *ptr2 = y_ptr + 30720;
    uint4 temp = *((uint4 *)ptr1);
    *((uint4 *)ptr2) = temp;
  }
}

As you can see, the problem size is 1000x32, the tile size is 128x32, and we launched 16 warps (512 threads). Triton deduces that each thread needs to process 128x32/512 = 8 elements, so it vectorizes the access with ld.global.v4.b32 and unrolls it twice. Triton also handles the last tile correctly in this case.
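The mapping can be checked exhaustively on the CPU. The sketch below takes the lane_row/lane_col formulas from the reverse-engineered kernel_230424_02 as given (elements_of_thread is a hypothetical helper). Each of the 512 threads covers two 4-wide vectors 64 rows apart, and the check confirms that together they tile the 128x32 block exactly once.

```python
TILE_M, TILE_N, THREADS = 128, 32, 512


def elements_of_thread(tid):
    """Flat tile offsets touched by one thread in one loop iteration."""
    lane_row = (tid >> 3) & 127     # 0..63 for 512 threads
    lane_col = (tid << 2) & 28      # 0, 4, ..., 28
    elems = []
    for row_shift in (0, 64):       # the second vector is 2048 = 64 * 32 elements away
        for k in range(4):          # one ld.global.v4.b32 covers 4 consecutive columns
            elems.append((lane_row + row_shift) * TILE_N + lane_col + k)
    return elems


covered = [e for tid in range(THREADS) for e in elements_of_thread(tid)]
assert len(covered) == THREADS * 8                      # 8 elements per thread
assert sorted(covered) == list(range(TILE_M * TILE_N))  # every element exactly once
```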

Observations:

  • Triton doesn't schedule threads to parallelize the for loop: every iteration of a user-written for loop is executed by all threads;
  • Triton deduces how many elements each thread needs to process, and may vectorize and/or unroll if the tile size is larger than the number of threads per CTA;
  • If the tile size is smaller than the number of threads per CTA, Triton won't spread the redundant threads across loop iterations. The user has to re-organize the tile with a higher dimension so that the redundant threads can be fully utilized;

Vectorization

As mentioned earlier, Triton generates unrolled code when each thread processes multiple elements. When does Triton vectorize the code instead?

Let's start with kernel_230423_04, which is kernel_230423_03a with a static shape:

@triton.jit
def kernel_230423_04(
    x_ptr,
    y_ptr,
    N: tl.constexpr,  # static shape
    BLOCK_SIZE_N: tl.constexpr,
):
    idx = tl.program_id(0)
    offsets = idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    mask = offsets < N
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(y_ptr + offsets, x, mask=mask)

def launch_kernel_230423_04():
    n = 1000
    BLOCK_SIZE_N = 128
    x = torch.randn((n, ), device='cuda')
    y = torch.zeros((n, ), device='cuda')
    grid = triton.cdiv(n, BLOCK_SIZE_N)

    kernel_230423_04[(grid, )](x, y, n, BLOCK_SIZE_N, num_warps=1)
    assert torch.allclose(x, y)

The generated PTX contains vectorized load/store:

add.s64     %rd1, %rd3, %rd5;
@%p1 ld.global.v4.b32 { %r5, %r6, %r7, %r8 }, [ %rd1 + 0 ];
add.s64     %rd2, %rd4, %rd5;
@%p1 st.global.v4.b32 [ %rd2 + 0 ], { %r5, %r6, %r7, %r8 };

Its corresponding CUDA code is:

__global__ void kernel_230423_04(int4 *x_ptr, int4 *y_ptr) {
  int lane = (threadIdx.x << 2) & 124;
  int gid = (blockIdx.x << 7) | lane;
  int idx = gid >> 2;

  if (gid < 1000) {
    y_ptr[idx] = x_ptr[idx];
  }
}

Observations:

  • Triton can generate vectorized code for static shapes. The vectorization width depends on the divisibility of N. For fp32, if N is a multiple of 4, Triton generates ldg.128; if N is a multiple of 2, it generates ldg.64. If N is not a multiple of 2, e.g. 1001, Triton can't generate vectorized code.

What about dynamic shapes? Let's try kernel_230423_05, which uses a single CTA to iterate over all tiles:

@triton.jit
def kernel_230423_05(
    x_ptr,
    y_ptr,
    n,
    BLOCK_SIZE_N: tl.constexpr,
):
    tiles = (n + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N
    for i in range(tiles):
        offsets = i * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        mask = offsets < n
        x = tl.load(x_ptr + offsets, mask=mask)
        tl.store(y_ptr + offsets, x, mask=mask)

def launch_kernel_230423_05():
    n = 1000
    BLOCK_SIZE_N = 128
    x = torch.randn((n, ), device='cuda')
    y = torch.zeros((n, ), device='cuda')

    kernel_230423_05[(1, )](x, y, n, BLOCK_SIZE_N, num_warps=1)
    assert torch.allclose(x, y)

The generated PTX contains unrolled load/store:

add.s64     %rd3, %rd1, %rd11;
add.s64     %rd4, %rd3, 4;
add.s64     %rd5, %rd3, 8;
add.s64     %rd6, %rd3, 12;
@%p2 ld.global.b32 { %r15 }, [ %rd3 + 0 ];
@%p3 ld.global.b32 { %r16 }, [ %rd4 + 0 ];
@%p4 ld.global.b32 { %r17 }, [ %rd5 + 0 ];
@%p5 ld.global.b32 { %r18 }, [ %rd6 + 0 ];
add.s64     %rd7, %rd2, %rd11;
add.s64     %rd8, %rd7, 4;
add.s64     %rd9, %rd7, 8;
add.s64     %rd10, %rd7, 12;
@%p2 st.global.b32 [ %rd7 + 0 ], { %r15 };
@%p3 st.global.b32 [ %rd8 + 0 ], { %r16 };
@%p4 st.global.b32 [ %rd9 + 0 ], { %r17 };
@%p5 st.global.b32 [ %rd10 + 0 ], { %r18 };

Its corresponding CUDA code is:

__global__ void kernel_230423_05(int *x_ptr, int *y_ptr, int n) {
  int lane = (threadIdx.x << 2) & 124;
  int m = n + 127;
  if (m >= 128) {
    int i = 0, tile_base = 0;  // tile_base = i * 128 (renamed so it doesn't shadow the built-in blockIdx)
    do {
      int gid = tile_base | lane;
      int gid1 = gid + 1;
      int gid2 = gid + 2;
      int gid3 = gid + 3;

      if (gid < n) {
        y_ptr[gid] = x_ptr[gid];
      }
      if (gid1 < n) {
        y_ptr[gid1] = x_ptr[gid1];
      }
      if (gid2 < n) {
        y_ptr[gid2] = x_ptr[gid2];
      }
      if (gid3 < n) {
        y_ptr[gid3] = x_ptr[gid3];
      }

      i += 1;
      tile_base += 128;
    } while (i < m / 128);  // tiles = (n + 127) / 128; the PTX computes this signed division with shifts
  }
}

Even if we give the compiler the hint n = tl.multiple_of(n, 4), it still can't generate vectorized code. Let's try making the number of elements a multiple of 16:

@triton.jit
def kernel_230423_06(  # same as kernel_230423_05
    x_ptr,
    y_ptr,
    n,
    BLOCK_SIZE_N: tl.constexpr,
):
    tiles = (n + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N
    for i in range(tiles):
        offsets = i * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        mask = offsets < n
        x = tl.load(x_ptr + offsets, mask=mask)
        tl.store(y_ptr + offsets, x, mask=mask)

def launch_kernel_230423_06():
    n = 1008  # make it multiple of 16
    BLOCK_SIZE_N = 128
    x = torch.randn((n, ), device='cuda')
    y = torch.zeros((n, ), device='cuda')

    kernel_230423_06[(1, )](x, y, n, BLOCK_SIZE_N, num_warps=1)
    assert torch.allclose(x, y)

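Then the generated PTX is vectorized. This is most likely because Triton's JIT specializes integer arguments that are divisible by 16 (the d in mangled kernel names such as kernel_03a_0d1d2 marks such arguments). The compiler can then prove that a 4-wide vector never straddles the boundary n: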

add.s64     %rd3, %rd1, %rd5;
@%p2 ld.global.v4.b32 { %r15, %r16, %r17, %r18 }, [ %rd3 + 0 ];
add.s64     %rd4, %rd2, %rd5;
@%p2 st.global.v4.b32 [ %rd4 + 0 ], { %r15, %r16, %r17, %r18 };

Another way is to specialize the last tile, since we know for sure that no tile except the last can go out of bounds. We can drop the mask for those tiles, so Triton doesn't need to consider whether the boundary falls inside a vectorized instruction:

@triton.jit
def kernel_230423_07(
    x_ptr,
    y_ptr,
    n,
    BLOCK_SIZE_N: tl.constexpr,
):
    tiles = (n + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N
    # We know for sure that all tiles except the last one stay in bounds.
    # Drop mask so that Triton can vectorize it, i.e. use ldg.128.
    for i in range(tiles - 1):
        offsets = i * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        x = tl.load(x_ptr + offsets)
        tl.store(y_ptr + offsets, x)
    # Last tile, aware of mask, no vectorization.
    offsets = (tiles - 1) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    mask = offsets < n
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(y_ptr + offsets, x, mask=mask)

def launch_kernel_230423_07():
    n = 1000
    BLOCK_SIZE_N = 128
    x = torch.randn((n, ), device='cuda')
    y = torch.zeros((n, ), device='cuda')

    kernel_230423_07[(1, )](x, y, n, BLOCK_SIZE_N, num_warps=1)
    assert torch.allclose(x, y)

Triton generates vectorized code for all the preceding tiles, and unrolled code for the last tile:

add.s64     %rd3, %rd1, %rd5;
@%p2 ld.global.v4.b32 { %r16, %r17, %r18, %r19 }, [ %rd3 + 0 ];
add.s64     %rd4, %rd2, %rd5;
@%p2 st.global.v4.b32 [ %rd4 + 0 ], { %r16, %r17, %r18, %r19 };
...
...
add.s64     %rd6, %rd1, %rd14;
add.s64     %rd7, %rd6, 4;
add.s64     %rd8, %rd6, 8;
add.s64     %rd9, %rd6, 12;
@%p5 ld.global.b32 { %r29 }, [ %rd6 + 0 ];
@%p6 ld.global.b32 { %r30 }, [ %rd7 + 0 ];
@%p7 ld.global.b32 { %r31 }, [ %rd8 + 0 ];
@%p8 ld.global.b32 { %r32 }, [ %rd9 + 0 ];
add.s64     %rd10, %rd2, %rd14;
add.s64     %rd11, %rd10, 4;
add.s64     %rd12, %rd10, 8;
add.s64     %rd13, %rd10, 12;
@%p5 st.global.b32 [ %rd10 + 0 ], { %r29 };
@%p6 st.global.b32 [ %rd11 + 0 ], { %r30 };
@%p7 st.global.b32 [ %rd12 + 0 ], { %r31 };
@%p8 st.global.b32 [ %rd13 + 0 ], { %r32 };

Its corresponding CUDA code is:

__global__ void kernel_230423_07(int *x_ptr, int *y_ptr, int n) {
  int lane = (threadIdx.x << 2) & 124;
  int m = n + 127;
  int num_blocks = m / 128;  // tiles; the PTX computes this signed division with shifts

  if (m >= 256) {
    int i = 0, tile_base = 0;  // tile_base = i * 128 (renamed so it doesn't shadow the built-in blockIdx)
    while (i < num_blocks - 1) {
      int gid = tile_base | lane;
      int idx = gid >> 2;
      int4 temp = reinterpret_cast<int4 *>(x_ptr)[idx];
      reinterpret_cast<int4 *>(y_ptr)[idx] = temp;

      i += 1;
      tile_base += 128;
    }
  }

  int gid_last_block = ((num_blocks - 1) << 7) | lane;
  int gid1 = gid_last_block + 1;
  int gid2 = gid_last_block + 2;
  int gid3 = gid_last_block + 3;

  if (gid_last_block < n) {
    y_ptr[gid_last_block] = x_ptr[gid_last_block];
  }
  if (gid1 < n) {
    y_ptr[gid1] = x_ptr[gid1];
  }
  if (gid2 < n) {
    y_ptr[gid2] = x_ptr[gid2];
  }
  if (gid3 < n) {
    y_ptr[gid3] = x_ptr[gid3];
  }
}

Conclusions: Triton can generate vectorized code only if it knows for sure that each thread's vectorized memory access (e.g. ldg.128, ldg.64) won't cross the valid memory boundary. Triton can use vectorization in the following situations:

  • The input has a static shape that is a multiple of at least 2;
  • The input has a dynamic shape that is a multiple of 16;
  • The input has a dynamic shape, and the user specializes the last tile;
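These rules can be summarized as a small model. Pick the widest vector, capped at 16 bytes and at the number of elements each thread owns, whose width divides what the compiler can prove about n; otherwise fall back to narrower widths. The sketch below is our own approximation, not Triton's actual alignment analysis. vector_width and mask_uniform are hypothetical helpers, and the divisible-by-16 case for dynamic shapes reflects the experiments above. mask_uniform brute-forces the underlying safety condition: no vector may straddle the boundary n.

```python
def vector_width(n, elems_per_thread, static_shape, masked=True,
                 elem_bytes=4, max_bytes=16):
    """Hypothetical model of the vectorization rules above (not Triton's code).

    Returns the number of elements per load/store instruction:
    4 -> ldg.128, 2 -> ldg.64, 1 -> scalar ldg.32 (for fp32).
    """
    width = min(elems_per_thread, max_bytes // elem_bytes)
    if not masked:
        return width  # e.g. the unmasked, specialized tiles of kernel_230423_07
    # Divisibility of n the compiler can rely on: the exact value for a static
    # shape, 16 for a dynamic n specialized as divisible by 16, else nothing.
    if static_shape:
        known = n
    else:
        known = 16 if n % 16 == 0 else 1
    while width > 1 and known % width != 0:
        width //= 2
    return width


def mask_uniform(n, width, length):
    """True if every width-aligned group in [0, length) is entirely in or out of bounds."""
    for start in range(0, length, width):
        flags = [start + k < n for k in range(width)]
        if any(flags) and not all(flags):
            return False
    return True


print(vector_width(1000, 4, static_shape=True))    # 4: ldg.128
print(vector_width(1002, 4, static_shape=True))    # 2: ldg.64
print(vector_width(1001, 4, static_shape=True))    # 1: no vectorization
print(vector_width(1000, 4, static_shape=False))   # 1: dynamic, not a multiple of 16
print(vector_width(1008, 4, static_shape=False))   # 4: dynamic, multiple of 16
print(vector_width(1000, 4, static_shape=False, masked=False))  # 4: mask dropped
```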

Reduction

@triton.jit
def kernel_230424_03(
    x_ptr,
    y_ptr,
    N: tl.constexpr,  # static shape
):
    cols = tl.arange(0, N)
    mask = cols < N
    x = tl.load(x_ptr + cols, mask=mask, other=0.0)
    total = tl.sum(x, axis=0)
    tl.store(y_ptr, total)

def launch_kernel_230424_03():
    n = 128
    x = torch.randn((n, ), device='cuda')
    y = torch.zeros((1, ), device='cuda')

    kernel_230424_03[(1, )](x, y, n, num_warps=4)
    assert torch.allclose(x.sum(), y)

Its corresponding CUDA code is:

__global__ void kernel_230424_03(float *input, float *output) {
  __shared__ float global_smem[128];

  int tid = threadIdx.x;
  int idx = tid & 127;
  int lane = tid & 31;
  int warp_id = tid >> 5;  // the PTX computes (tid >> 3) & ~3, i.e. warp_id * 4 as a byte offset

  float res = input[idx];

#pragma unroll
  for (unsigned N = 16; N > 0; N >>= 1) {
    res += __shfl_xor_sync(0xFFFFFFFF, res, N);
  }
  if (lane == 0) {
    global_smem[warp_id] = res;
  }
  __syncthreads();

  float shared_val = (tid < 4) ? global_smem[tid] : 0.0f;  // only the 4 per-warp partials are valid
#pragma unroll
  for (unsigned N = 2; N > 0; N >>= 1) {
    shared_val += __shfl_xor_sync(0xFFFFFFFF, shared_val, N);
  }
  if ((tid < 4) && ((tid & 3) == 0)) {
    global_smem[tid] = shared_val;
  }
  __syncthreads();

  output[0] = global_smem[0];
}

Observations:

  • tl.sum() is lowered first to an intra-warp reduction with warp shuffle instructions, then to an inter-warp reduction through shared memory;
  • tl.sum() performs a butterfly reduction instead of a tree reduction, so every thread ends up with the reduced value;

However, the second __syncthreads() is not necessary here: thread 0 already holds the final value in a register, so only thread 0 needs to write out the data.
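The butterfly pattern is easy to check on the CPU. This sketch emulates __shfl_xor_sync over Python lists. shfl_xor, butterfly_sum and block_sum are hypothetical helpers, and integer inputs are used so the sums are exact. After the xor offsets 16, 8, 4, 2, 1, every lane holds the warp total, so thread 0 already has the final result after the second phase.

```python
WARP_SIZE = 32


def shfl_xor(values, lane_mask):
    """Emulate __shfl_xor_sync: lane i receives the value held by lane i ^ lane_mask."""
    return [values[i ^ lane_mask] for i in range(len(values))]


def butterfly_sum(values):
    """XOR-butterfly all-reduce over a power-of-2 number of lanes."""
    res = list(values)
    offset = len(res) // 2
    while offset > 0:                  # 16, 8, 4, 2, 1 for a full warp
        partner = shfl_xor(res, offset)
        res = [a + b for a, b in zip(res, partner)]
        offset //= 2
    return res                         # every lane ends with the full sum


def block_sum(values, num_warps=4):
    """Two-phase reduction modeled on kernel_230424_03."""
    # Phase 1: intra-warp butterfly; lane 0 of each warp writes its partial to smem.
    smem = [butterfly_sum(values[w * WARP_SIZE:(w + 1) * WARP_SIZE])[0]
            for w in range(num_warps)]
    # (__syncthreads() here.) Phase 2: butterfly over the partials (offsets 2, 1).
    return butterfly_sum(smem)


print(butterfly_sum(list(range(32)))[:4])   # [496, 496, 496, 496]
print(block_sum(list(range(128))))          # [8128, 8128, 8128, 8128]
```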

Now let's see what happens with inter-CTA synchronization. Here we use multiple CTAs to reduce a tensor to a single float, with only one warp per CTA:

@triton.jit
def kernel_230424_04(
    x_ptr,
    y_ptr,
    N: tl.constexpr,  # static shape
    BLOCK_SIZE_N: tl.constexpr,
):
    pid = tl.program_id(0)
    cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    mask = cols < N
    x = tl.load(x_ptr + cols, mask=mask, other=0.0)
    total = tl.sum(x, axis=0)
    tl.atomic_add(y_ptr, total)

def launch_kernel_230424_04():
    n = 128
    x = torch.randn((n, ), device='cuda')
    y = torch.zeros((1, ), device='cuda')  # has to be zeros
    BLOCK_SIZE_N = 32
    grid = triton.cdiv(n, BLOCK_SIZE_N)

    kernel_230424_04[(grid, )](x, y, n, BLOCK_SIZE_N, num_warps=1)
    assert torch.allclose(x.sum(), y)

After reverse engineering the PTX, we get the following CUDA code:

__global__ void kernel_230424_04(float *input, float *output) {
  __shared__ float global_smem[32];

  int tid = threadIdx.x;
  int ctaid = blockIdx.x;
  int globalIdx = (ctaid << 5) | tid;

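  // tl.load(..., mask=mask, other=0.0): masked-off lanes contribute 0.0 and
  // still take part in the shuffles and barriers below.
  float temp = (globalIdx < 128) ? input[globalIdx] : 0.0f;

  // Perform butterfly shuffle operations and sum the values. The 0x1F seen in
  // the PTX shfl.sync.bfly is its clamp operand for the default width of 32,
  // not a width argument, so no width is passed here.
#pragma unroll
  for (unsigned N = 16; N > 0; N >>= 1) {
    temp += __shfl_xor_sync(0xFFFFFFFF, temp, N);
  }

  if (tid == 0) {
    global_smem[0] = temp;
  }
  __syncthreads();

  if (tid < 1) {
    global_smem[tid] = global_smem[0];
  }
  __syncthreads();

  if (tid == 0) {
    atomicAdd(output, global_smem[0]);
  }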
}

Observations:

  • Atomic operations like tl.atomic_add() are translated directly to CUDA atomicAdd();
  • Triton could do better here: neither the shared memory nor the __syncthreads() calls are required, since the butterfly already leaves the warp total in thread 0's register;
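The host-side contract of this kernel can be sketched sequentially. Each CTA contributes one partial sum, masked lanes contribute other=0.0, and the atomics accumulate into whatever y already holds, which is why y has to be zeroed before the launch. multi_cta_sum is a hypothetical helper. It adds partials in program-id order, whereas real atomics give no ordering guarantee.

```python
def multi_cta_sum(x, block_size=32, y_init=0.0):
    """Sequential model of kernel_230424_04 (hypothetical helper).

    Each "CTA" reduces its own masked tile and atomically adds the partial
    sum to y[0]. Real atomics run in an unspecified order, so float results
    can differ in the last bits between runs; this model adds in pid order.
    """
    n = len(x)
    grid = (n + block_size - 1) // block_size          # triton.cdiv(n, BLOCK_SIZE_N)
    y = [y_init]
    for pid in range(grid):
        cols = [pid * block_size + k for k in range(block_size)]
        tile = [x[c] if c < n else 0.0 for c in cols]  # tl.load(..., mask=mask, other=0.0)
        total = sum(tile)                              # tl.sum(x, axis=0)
        y[0] += total                                  # tl.atomic_add(y_ptr, total)
    return y[0]


x = [float(i) for i in range(128)]
print(multi_cta_sum(x))              # 8128.0
print(multi_cta_sum(x, y_init=1.0))  # 8129.0: y must start at zero
```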

Memory

Users can't control whether loaded data is stored in shared memory or in the register file. However, we can infer some rules about where the data is stored.

Let's write a simple transpose kernel with only 1 CTA and 1 warp:

@triton.jit
def kernel_230424_05(
    x_ptr,
    y_ptr,
    N: tl.constexpr,  # static shape
):
    off = tl.arange(0, N)
    x_ptrs = x_ptr + off[:, None] * N + off[None, :]
    y_ptrs = y_ptr + off[None, :] * N + off[:, None]
    x = tl.load(x_ptrs)
    tl.store(y_ptrs, x)

def launch_kernel_230424_05():
    n = 16
    x = torch.randint(10, (n, n), device='cuda').float()
    y = torch.empty((n, n), device='cuda')  # every element is overwritten, so empty is fine

    kernel_230424_05[(1,)](x, y, n, num_warps=1)
    assert torch.allclose(x.t(), y)

Reverse engineering the PTX code, we get the corresponding CUDA code:

__global__ void kernel_230424_05(float *input, float *output) {
  __shared__ float global_smem[17 * 17];

  int tid = threadIdx.x;
  int row = (tid >> 2) & 0xF;
  int col = (tid << 2) & 0xC;
  int off1 = (row << 4) + col;
  int off2 = (row << 4) + 128 + col;

  if (tid < 32) {
    int smemOffset = col * 17 + row;
    float4 inData1 = *((float4 *)(input + off1));
    float4 inData2 = *((float4 *)(input + off2));

#pragma unroll(4)
    for (int i = 0; i < 4; i++) {
      *(global_smem + smemOffset + i * 17) = *((float *)&inData1 + i);
    }
#pragma unroll(4)
    for (int i = 0; i < 4; i++) {
      *(global_smem + smemOffset + 8 + i * 17) = *((float *)&inData2 + i);
    }
  }
  __syncthreads();

  if (tid < 32) {
    int smemOffset = row * 17 + col;
    float4 outData1, outData2;
#pragma unroll(4)
    for (int i = 0; i < 4; i++) {
      *((float *)&outData1 + i) = *(global_smem + smemOffset + i);
    }
#pragma unroll(4)
    for (int i = 0; i < 4; i++) {
      *((float *)&outData2 + i) = *(global_smem + smemOffset + 136 + i);
    }

    *((float4 *)(output + off1)) = outData1;
    *((float4 *)(output + off2)) = outData2;
  }
}

As the code above shows, when transposing a 16x16 tile with 32 threads, Triton infers that each thread can use ldg.128. The CTA therefore splits the data into two halves (rows 0-7 and rows 8-15), and the two halves are unrolled. Triton holds the intermediate data in shared memory, with a padded row stride of 17, and performs the transpose there.
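The 17-word row stride of global_smem looks like padding against shared-memory bank conflicts. The sketch below assumes 32 banks of 4-byte words. For the reverse-engineered store pattern smem[(col + i) * stride + row], it counts how many of the 32 lanes land in the same bank. max_bank_conflict is a hypothetical helper, and it models only this one store, not Triton's layout logic. Under this model, padding the stride from 16 to 17 cuts the worst case from 4-way to 2-way conflicts.

```python
NUM_BANKS = 32  # shared memory: 32 banks of 4-byte words


def max_bank_conflict(row_stride, i=0):
    """Worst-case lanes per bank for the i-th transposing store of the
    reverse-engineered kernel_230424_05: smem[(col + i) * row_stride + row]."""
    hits = {}
    for tid in range(32):
        row = (tid >> 2) & 0xF
        col = (tid << 2) & 0xC
        bank = ((col + i) * row_stride + row) % NUM_BANKS
        hits[bank] = hits.get(bank, 0) + 1
    return max(hits.values())


for stride in (16, 17):
    print(stride, [max_bank_conflict(stride, i) for i in range(4)])
# 16 [4, 4, 4, 4]
# 17 [2, 2, 2, 2]
```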

Observations:

  • Triton puts the loaded tile in the register file by default;

  • Triton puts the tile into shared memory in three situations:

    1. The operator requires shared memory: e.g. tl.sum() for reduction;
    2. A layout transformation is required: e.g. a transpose;
    3. The operator requires shared memory operands: e.g. tl.dot() for GEMM;

Instruction

Now let's check whether Triton generates fast-math instructions, using a sigmoid function:

@triton.jit
def kernel_230424_06(
    x_ptr,
    y_ptr,
    N: tl.constexpr,  # static shape
):
    off = tl.arange(0, N)
    x = tl.load(x_ptr + off)
    x = tl.sigmoid(x)
    tl.store(y_ptr + off, x)

def launch_kernel_230424_06():
    n = 32
    x = torch.randn((n,), device='cuda')
    y = torch.empty((n,), device='cuda')

    kernel_230424_06[(1,)](x, y, n, num_warps=1)
    assert torch.allclose(x.sigmoid(), y)

Its corresponding CUDA kernel is:

__global__ void kernel_230424_06(float *input, float *output) {
  int tid = threadIdx.x;
  int idx = tid & 31;

  if (tid < 32) {
    float inValue = input[idx];
    output[idx] = 1.0f / (1.0f + expf(-inValue));
  }
}

However, the generated PTX has some differences:

// nvcc, no fast math                     | triton                              |  nvcc, fast math
ld.global.f32   %f1, [%rd6];              | @%p1 ld.global.b32 {%r1}, [%rd1+0]; | ld.global.f32   %f1, [%rd6];
neg.f32     %f2, %f1;                     | mov.b32     %f3, %r1;               | mul.f32     %f2, %f1, 0fBFB8AA3B;
mov.f32     %f3, 0f3F000000;              | mov.f32     %f4, 0f00000000;        | ex2.approx.f32  %f3, %f2;
mov.f32     %f4, 0f3BBB989D;              | sub.f32     %f5, %f4, %f3;          | add.f32     %f4, %f3, 0f3F800000;
fma.rn.f32  %f5, %f2, %f4, %f3;           | mul.f32     %f2, %f5, 0f3FB8AA3B;   | rcp.approx.f32  %f5, %f4;
mov.f32     %f6, 0f3FB8AA3B;              | ex2.approx.f32 %f1, %f2;            | cvta.to.global.u64  %rd7, %rd2;
mov.f32     %f7, 0f437C0000;              | add.f32     %f6, %f1, 0f3F800000;   | add.s64     %rd8, %rd7, %rd5;
cvt.sat.f32.f32     %f8, %f5;             | mov.b32     %r4, %f6;               | st.global.f32   [%rd8], %f5;
mov.f32     %f9, 0f4B400001;              | mov.u32     %r3, 1065353216;        |
fma.rm.f32  %f10, %f8, %f7, %f9;          | div.full.f32 %r5, %r3, %r4;         |
add.f32     %f11, %f10, 0fCB40007F;       | add.s64     %rd2, %rd4, %rd5;       |
neg.f32     %f12, %f11;                   | @%p1 st.global.b32 [%rd2+0], {%r5}; |
fma.rn.f32  %f13, %f2, %f6, %f12;         |
mov.f32     %f14, 0f32A57060;             |
fma.rn.f32  %f15, %f2, %f14, %f13;        |
mov.b32     %r3, %f10;                    |
shl.b32     %r4, %r3, 23;                 |
mov.b32     %f16, %r4;                    |
ex2.approx.ftz.f32  %f17, %f15;           |
fma.rn.f32  %f18, %f17, %f16, 0f3F800000; |
rcp.rn.f32  %f19, %f18;                   |
cvta.to.global.u64  %rd7, %rd2;           |
add.s64     %rd8, %rd7, %rd5;             |
st.global.f32   [%rd8], %f19;             |

Observations:

  • Triton uses fast math for exp() (ex2.approx.f32), but not for the reciprocal: it emits div.full.f32 where NVCC with fast math emits rcp.approx.f32;
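Decoding the hex constants in the listing shows what each column computes. The sketch below uses two hypothetical helpers, ptx_f32 and sigmoid_like_triton_ptx. It decodes PTX 0f literals with struct and replays the Triton column. 0f3FB8AA3B is log2(e), so the exponential is computed as 2^(-x * log2(e)) with ex2.approx, and the final division is a plain 1.0 / (1 + e). Python floats are double precision, so the approximate hardware rounding is not modeled.

```python
import math
import struct


def ptx_f32(literal):
    """Decode a PTX single-precision hex literal such as '0f3FB8AA3B'."""
    if not literal.startswith("0f"):
        raise ValueError("expected a 0f-prefixed PTX float literal")
    return struct.unpack(">f", bytes.fromhex(literal[2:]))[0]


def sigmoid_like_triton_ptx(x):
    """Replay the Triton column above in Python floats (double precision,
    so ex2.approx and div.full rounding are not modeled)."""
    neg = ptx_f32("0f00000000") - x               # sub.f32 %f5, %f4, %f3
    scaled = neg * ptx_f32("0f3FB8AA3B")          # mul.f32 by log2(e)
    e = 2.0 ** scaled                             # ex2.approx.f32
    return 1.0 / (e + ptx_f32("0f3F800000"))      # add.f32 1.0, then div.full.f32


print(ptx_f32("0f3FB8AA3B"), math.log2(math.e))  # both approximately 1.4426950
print(sigmoid_like_triton_ptx(0.0))              # 0.5
```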

Tensor Core

@triton.jit
def kernel_230429_01(
    a_ptr,
    b_ptr,
    d_ptr,
    M: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
):
    offs_m = tl.arange(0, M)
    offs_n = tl.arange(0, N)
    offs_k = tl.arange(0, K)
    a_ptrs = a_ptr + (offs_m[:, None] * K + offs_k[None, :])
    b_ptrs = b_ptr + (offs_k[:, None] * N + offs_n[None, :])
    d_ptrs = d_ptr + (offs_m[:, None] * N + offs_n[None, :])
    a = tl.load(a_ptrs)
    b = tl.load(b_ptrs)
    d = tl.dot(a, b)
    tl.store(d_ptrs, d)

def launch_kernel_230429_01():
    m, n, k = 16, 16, 16
    a = torch.randn((m, k), device='cuda', dtype=torch.float16)
    b = torch.randn((k, n), device='cuda', dtype=torch.float16)
    d = torch.empty((m, n), device='cuda', dtype=torch.float16)

    kernel_230429_01[(1,)](a, b, d, m, n, k, num_warps=1)
    assert torch.allclose(a @ b, d, atol=1e-2, rtol=0)