Slang で自動微分シェーダ開発入門

Slang で自動微分シェーダ開発入門

はじめに

こんにちは。シリコンスタジオ 研究開発室の川口です。

突然ですが、皆さんは自動微分や微分可能レンダリングといった言葉を耳にしたことはありますか?
機械学習に触れたことがある方には、既に馴染みのある分野かもしれません。
オフラインレンダリングの分野では Mitsuba3 という学術用途の微分可能レンダラーが有名で、多くの研究で利用されています。

しかし、ゲーム開発をはじめとするリアルタイムレンダリングの分野では、まだあまり関わりのない概念かもしれません。

自動微分と微分可能レンダリングについて

さて、自動微分や微分可能レンダリングとはなんでしょう。

調べると、主に機械学習の分野でたくさんの情報が見つかると思います。
詳しい説明はそれらに譲るとして、ここではざっくりとした説明だけにとどめます。

自動微分とは、プログラムのアルゴリズムの一つで、実装された処理の微分形を自動的に得る処理を指します。
例えば $f(x)=x^2$ という関数を考えたときに、この導関数(微分)は $f'(x)=2x$ です。
これをプログラムで表した場合

float f(float x) { return x*x; }
float f_prime(float x) { return 2*x; }

となります。
自動微分の仕組みがあるプログラムでは float f(float x) だけを定義して、微分が必要なときに AD(f)(x) とするだけで float f_prime(float x) にあたる処理が自動的に生成されます。  
AD(f)(x) は適当な疑似コードで AD(f) (auto differential) が f を微分した関数を表すものとします。)
簡単な式の微分にはあまり意味があるようには見えませんが、手で計算しろといわれて困るような複雑な数式でも、分岐やループのようなプログラム特有の処理でも、自動微分のプログラミングでは記述する処理全ての微分形が得られます。

プログラムの微分ができて嬉しいこととして、最適化問題を効率的に解けるようになるという点が挙げられます。
最適化問題を解くためには、基本的に損失関数の勾配(偏微分)が必要になります。
複雑な問題の勾配は計算が大変ですが、自動微分の仕組みがあればその手間の多くを省け、計算効率もよくなる可能性があります。

その一例がレンダリングです。
レンダリングの処理は、シーンのパラメータ(メッシュや光源、カメラ、マテリアルなど)を入力に画像を出力とする一つの巨大な関数とみなすことができます。
シーンの一つのパラメータ(例えば光源の色の1成分など)について画像がどのように変化するか、というのがレンダリングの微分(e.g. $\frac{\partial レンダリング画像}{\partial 光源Aの色のRチャンネル}$)です。
無数にあるシーンのパラメータと膨大な画素数のレンダリング画像について、これをすべて微分してくださいと言われても人間には不可能ですし、単なる数値計算手法でも高次すぎて手に負えません。

自動微分プログラミングでレンダラーを実装すると、自動的にレンダリング処理の微分が得られ、レンダリング処理に関する最適化問題が解けるようになります。
(例えば、レンダリング関数 $f(x)$ について、入力画像 $I$ との誤差関数 $E=\|f(x) – I\|^2$ の最適化問題を $\frac{\partial E}{\partial x}=0$ を使って解けるようになる。)  

これが微分可能レンダリングで、NeRF や Gaussian Splatting に代表される画像から3Dシーンを再構成する技術に使われています。

Slang について

少し話は変わって、昨年末の SIGGRAPH Asia 2024 で Slang というシェーダ言語についての講演があり注目を集めました。
https://www.khronos.org/news/press/khronos-group-launches-slang-initiative-hosting-open-source-compiler-contributed-by-nvidia

Slang は NVIDIA が SIGGRAPH 2018 で発表した、オープンソースのリアルタイムグラフィクス向けのシェーダ言語です。
http://graphics.cs.cmu.edu/projects/slang/
グラフィクス API やプラットフォーム、GPU が多様化している昨今の環境で、大規模で複雑かつ高パフォーマンスを期待されるシェーダ開発を支える新たな共通基盤として Slang は開発されています。  
Slang には D3D12, Vulkan, Metal, D3D11, OpenGL, CUDA, さらには CPU で動作するコードを生成する仕組みが用意されています。

SIGGRAPH Asia 2023 では Slang を微分可能に拡張する取り組みが紹介され、微分可能レンダリングの実装環境として Slang という選択肢が上がってきました。
https://research.nvidia.com/labs/rtr/publication/bangaru2023slangd/

2024年11月に Slang のオープンソースプロジェクトが Khronos Group へ移管したことが発表され、多くの Slang の機能のアップデートも紹介されました。  
https://shader-slang.org/blog/2024/11/20/theres-a-lot-going-on-with-slang/
SIGGRAPH Asia 2024 の講演もその一環で Slang の役割や自動微分をはじめとする機能についての紹介がありました。

2025年1月にあった NVIDIA の RTX 5000 番台 GPU の発表の中でも、RTX Neural Shaders という新しい取り組みが紹介され、そこでも Slang を利用すると言及されています。
https://developer.nvidia.com/blog/nvidia-rtx-neural-rendering-introduces-next-era-of-ai-powered-graphics-innovation/
https://github.com/NVIDIA-RTX/RTXNS

近年の AI 技術の発展は CG 分野でも非常に活発で、オフラインレンダリングに留まらずリアルタイムレンダリングへも波及しつつあります。
その開発を支える言語として Slang は今後重要なツールになってくると予想されます。

私のようなゲームなどのリアルタイムレンダリングに関わっている人間からすると、Slang は HLSL に非常に近い構文を持つため、ぱっと見てとっつきやすい言語ではあります。
しかし、そこに自動微分や微分可能レンダリングといった新しい概念が入り込んでくると、途端にリアルタイムレンダリングの文脈から外れて従来の知識ではうまく扱えなくなると感じました。

そこで今回は、この Slang を使って自動微分の仕組みがあるシェーダでどのようなことができるのか、今までのシェーダ開発から何が変わるのかを検証していきたいと思います。
私自身が勉強として試した内容を紹介していくので、非効率だったり間違っていたりする内容を含むかもしれませんがご了承ください。

皆さんの Slang 開発の入門として一つのきっかけになれば幸いです。

Slang で自動微分シェーダ開発入門

前置きが長くなりましたが、私が試した Slang での開発について色々紹介していきます。
今回は自動微分の活用に重きを置いています。
微分可能レンダリングについては、また別の機会に紹介したいと考えています。

また Slang は HLSL をほとんどそのまま書くことができますが、HLSL に無い便利な言語機能もたくさんあるので、それらも極力活用しています。

リアルタイムレンダリングは基本的にメッシュレンダリングで構成されていますが、今回はシンプルなコードで完結させるため、Shadertoy などでよく使われる、メッシュを入力に取らない 1 パスで完結するレンダリングを取り上げます。
直接ゲームなどに応用はできないかもしれませんが、Slang の構文の理解や自動微分という考えを理解するにはよいサンプルになると思います。

執筆時点の環境(2025年3月)

  • Slang (v2025.6.1)

Slang は日々アップデートが続いているので、リリース時期や動かす環境によって結果が変わったり動かなくなったりする可能性があります。
今回は Slang で Compute Shader を記述して、それを HLSL で出力し Unity6 上で動作させています。
Unity である必要は全くないので、動かす際はそれぞれの環境で試してみてください。

Slang を Shadertoy のようにブラウザ上で動作させる Playground というデモ環境もあります。
今回のコードも少し修正すれば Playground で試すこともできるでしょう。

レイマーチングに自動微分を活用する

まず Slang で球を描画するシンプルなレイマーチングを実装してみます。

Slang の構文

Slang の構文は基本的に HLSL と同じです。
そこに C# や C++、Rust などにあるような便利な言語機能が追加されています。

自動微分を使う前に、いくつかの構文についても紹介しておきましょう。
HLSL と同様に Slang はオブジェクト指向言語ではないのでクラスの概念はありませんが、オブジェクト指向のような便利な構文が使えます。
(実は HLSL にも interface や template の仕組みがありますが、使える環境が制限されるなどの要因からか活用されている印象はあまりありません。)

Constructor

コンストラクタは構造体のメンバを初期化する仕組みで、__init() と宣言します。
独自の型の定義と初期化を簡潔に記述するのに利用できるでしょう。  

レイマーチングに必要なレイを定義しました。

struct Ray
{
    float3 origin;
    float3 direction;

    __init(float3 p, float3 d)
    {
        origin = p;
        direction = d;
    }

    float3 evalPos(float t)
    {
        return origin + t * direction;
    }
};

Interface

インターフェースは共通の機能を持つ構造体を定義するための仕組みです。
レイマーチングにおける描画オブジェクトは SDF(符号付き距離関数)で定義されるので、それをインターフェースを使って実装してみました。
SDF には共通して距離を評価する関数が必要なので、インターフェースで dist メンバ関数を持つように定義しています。
ここでは球だけですが、他の SDF を使いたいときも ISDF インターフェースを使うことで簡潔でエラーの起きにくい実装が期待できます。
ISDF を使って複数種類の SDF を定義した実装は、本記事の最後のおまけサンプルに使っています。

interface ISDF
{
    float dist(float3 p);
};
struct Sphere : ISDF
{
    float3 center;
    float radius;

    __init(float3 _center, float r)
    {
        center = _center;
        radius = r;
    }

    float dist(float3 p)
    {
        return length(p - center) - radius;
    }
};

Generics

ジェネリクスは C++ でいうところのテンプレートで、異なる型で共通のロジックを利用するための仕組みです。
ジェネリック型パラメータ T はインターフェースで制約することができます。
ここでは SDF の値と法線を評価する関数をジェネリクスで実装しています。

float getDistance<T : ISDF>(float3 p, T sdf)
{
    return sdf.dist(p);
}

float3 getNormal<T : ISDF>(float3 p, T sdf) {
    float2 d = float2(0, 1e-4);
    return normalize(float3(
        sdf.dist(p + d.yxx) - sdf.dist(p - d.yxx),
        sdf.dist(p + d.xyx) - sdf.dist(p - d.xyx),
        sdf.dist(p + d.xxy) - sdf.dist(p - d.xxy)));
}

レイマーチングの実装

球体を描画するシンプルなレイマーチングの本体を実装します。

RWTexture2D<float4> result;
uniform float4 _TexelSize; // x:1/width, y:1/height, z:width, w:height

static const int NUM_STEP = 100;
static const float THRESHOLD = 1e-6;

Ray computeRay(uint2 tid)
{
    float2 pixelCoord = (2 * (tid.xy + 0.5) - _TexelSize.zw) * _TexelSize.y;
    return Ray(float3(0, 0, -2.), normalize(float3(pixelCoord, 1)));
}

[shader("compute")]
[numthreads(32, 32, 1)]
void simple_sphere(uint3 threadId: SV_DispatchThreadID)
{
    var ray = computeRay(threadId.xy);

    float3 lightDir = normalize(float3(-0.5, 0.4, -0.6));

    var sphere = Sphere(float3(0, 0, 0), 1);

    float t = 0;
    float3 color = 0.0;
    for (int i = 0; i < NUM_STEP; i++)
    {
        float3 pos = ray.evalPos(t);
        float d = getDistance(pos, sphere);

        if (abs(d) < THRESHOLD * t)
        {
            float3 n = getNormal(pos, sphere);

            color = pow(dot(lightDir, n) * 0.5 + 0.5, 4); // half Lambert

            break;
        }
        t += d;
    }

    result[threadId.xy] = float4(color, 1.0);
}

Slang ではシェーダのエントリポイントになる箇所に [shader("compute")] と宣言を付けます。
compute 以外にも vertexfragment、レイトレシェーダの raygeneration, closesthit など各種シェーダステージが存在しています。

computeRay 関数でピクセルの位置毎にレイを定義してレイマーチングを行います。
球の SDF を評価して衝突位置とその位置における法線を計算し、適当な Half-Lambert 風のシェーティングを行っています。

レイマーチングに自動微分を使う

さて、レイマーチングにおける法線を今は getNormal 関数で評価しています。

float3 getNormal<T : ISDF>(float3 p, T sdf) {
    float2 d = float2(0, 1e-4);
    return normalize(float3(
        sdf.dist(p + d.yxx) - sdf.dist(p - d.yxx),
        sdf.dist(p + d.xyx) - sdf.dist(p - d.xyx),
        sdf.dist(p + d.xxy) - sdf.dist(p - d.xxy)));
}

SDF (距離場)で定義される3次元形状は、距離値 0 が表面を定義し距離値が大きくなる方向(=勾配)がその面の法線を意味します。
getNormal 関数は、現在の位置から少しずらした位置の距離値の差分を取ることで距離場の勾配を計算する数値微分です。

この数値微分は自動微分に置き換えることができそうです。
やってみましょう。

Differentiable

Slang で自動微分を利用したい関数には [Differentiable] 属性を付けます。
今回は距離を評価する関数の勾配が欲しいので、dist 関数と getDistance 関数に属性をマークします。

そして計算中に微分しない(定数として扱う)パラメータには no_diff 修飾子を付けます。
getDistance の引数である T sdf は形状を定義する定数(Sphere が中心座標位置 center と半径 radius を持っているだけ)なので no_diff 修飾しておきます。
(あくまでパラメータだけの話なのでメンバ関数 dist の微分が必要なのとは関係ありません。)

struct Sphere : ISDF
{
    float3 center;
    float radius;

    __init(float3 _center, float r)
    {
        center = _center;
        radius = r;
    }

    [Differentiable]
    float dist(float3 p)
    {
        return length(p - center) - radius;
    }
};

[Differentiable]
float getDistance<T : ISDF>(float3 p, no_diff T sdf)
{
    return sdf.dist(p);
}

自動微分の利用

自動微分には、その計算方法に forward-mode と reverse-mode の二通りがあります。
それらの説明についてはここでは省略します。
詳しくは Slang のユーザガイドを参照してください。
https://shader-slang.org/slang/user-guide/autodiff.html

自動微分を自分で実装してみる、という試みを行っている記事もたくさんあるようなので、それらを探して勉強するのも良いと思います。(私も勉強中です。)

自動微分がどう実現されているかは置いておいて、利用する側からは、forward-mode と reverse-mode をどのように使い分けるかが重要です。

一般的に、関数の入力の次元が出力の次元よりも大きいときは、reverse-mode を使った方が効率がよいようです。
今回は、入力が3次元座標で出力が1次元距離値の getDistance 関数を微分したいので reverse-mode を使います。

float3 getNormalAD<T : ISDF>(float3 p, T sdf)
{
    DifferentialPair<float3> diffPos = diffPair(p);
    bwd_diff(getDistance)(diffPos, sdf, 1.0);

    return normalize(diffPos.d.xyz);
}

まず DifferentialPairdiffPair で微分パラメータを宣言します。
DifferentialPair は 関数の入力とその微分結果を保持する型です。
今回は距離関数を評価する3次元座標が入力なので p について diffPair を宣言します。

reverse-mode の自動微分は bwd_diff で行います。
その結果が diffPos に格納され、diffPos.p には元の入力の値(primal part)、diffPos.d には微分した値(derivative part)が入ります。

この diffPos.d こそが距離関数を微分して得られる勾配で、SDF が定義する形状の法線となります。

ちなみに

float getDistance<T : ISDF>(float3 p, no_diff T sdf)

この関数を自動微分したものは

void bwd_diff(getDistance) (inout DifferentialPair<float3> p, in T sdf, in float resultGradient)

のような bwd_diff(getDistance) という名前の新しい関数のように見ることができます。

ここで最終引数の resultGradient には 1 を入れます。
正確な説明ではありませんが、これは関数の出力ベクトルにかかるマスク(あるいはスケール)のように機能します。
Slang の自動微分のユーザガイドによると、resultGradient は reverse-mode の vector-Jacobian product の vector 部分に対応します。
入出力の次数が多い関数の微分では Jacobian 行列が大きくなるので、必要な要素だけを計算する為に必要という風に私は理解しています。

今回は単純に getDistance の関数の出力は1次元なので、resultGradient は1次元 float になり、1 を指定します。

余談

ところで、getDistance 関数は ISDFdist メンバ関数を呼ぶだけなので冗長だと思った方もいるかもしれません。
しかし、次のように直接 dist 関数の微分を利用しようとするとエラーが出たため、このような実装になっています。
メンバ関数を直接微分するような実装はできないようです(うまくやる方法もあるかもしれません)。

// bwd_diff(sdf.dist)(diffPos, 1.0); // error 30098: non-static function reference 'dist' is not allowed here.

結果

レイマーチングで法線を計算する処理を自動微分のバージョンに置き換えます。

float3 n = getNormalAD(pos, sphere);

color = n * 0.5 + 0.5; で法線を可視化したものも併せて、最終的な結果を示します。
左が数値微分、右が自動微分の結果で、目で見る限り同等の絵が得られています。
これで自動微分を使って距離場の法線(勾配)を計算できることがわかりました。

自動微分を使ったアンチエイリアシング

自動微分の活用としてもう一例試してみましょう。

レンダリングにおいて、微分を利用している物としてミップマップが上げられます。
遠いオブジェクトに詳細なテクスチャを貼ると必要以上に負荷がかかったり、高周波成分がアーティファクトとして現れてしまいます。
ミップマップは、その負荷やアーティファクトを改善するための技術で、テクスチャをサンプリングするときに隣接ピクセル間の座標の変化量に基づいてサンプリングするテクスチャの解像度を変更するというものです。
ミップマップの計算に使う隣接ピクセル間の座標の変化量こそ微分を意味していて、隣接ピクセルで変化量が大きい=視点から遠いということになり、変化量に応じてテクスチャの解像度を変えることでアンチエイリアシングの効果が得られます。  

今回はテクスチャのサンプリングではありませんが、数値的に作るチェッカーボードのパターンにミップマップの代替となる解析的なフィルタリングを施してみます。
この内容はこちらの Inigo Quilez 氏の記事とその Shadertoy 実装に基づいており、細かい処理や計算の説明はここでは省略します。
https://iquilezles.org/articles/checkerfiltering/
https://www.shadertoy.com/view/XlcSz2

この例では数値的なプリミティブ(球と平面)を使ったシェーダ内だけで完結するレイトレーシングでレンダリングします。

微分可能なレイ

前述の通り、ミップマップに必要な微分とは、ピクセル座標の変化に対する UV の変化量です。
なのでピクセル座標を入力に UV が得られる関数を設計する必要がありそうです。

まずレイマーチングと同様にレイを定義しますが、今回はレイ自体が微分可能である必要があります。
微分可能なレイ DRay とその初期化を行う関数を用意します。
レイの位置や(今は time=0 として動かしていませんが)アニメーションなどの設定はオリジナルの Shadertoy 実装と同じです(個人的なわかりやすさのため x と z 座標を入れ替えています)。

struct DRay : IDifferentiable
{
    float3 origin;
    float3 direction;

    [Differentiable]
    __init(float3 o, float3 d)
    {
        origin = o;
        direction = d;
    }

    [Differentiable]
    float3 evalPos(float t)
    {
        return origin + t * direction;
    }
};

[Differentiable]
DRay computeDRay(float2 p, no_diff float time)
{
    float2 mo = float2(0.0, 0.0);

    // (original shadertoy) calcCamera
    float an = 0.1 * sin(0.1 * time);
    float3 ro = float3(5.0 * sin(an), 0.5, 5.0 * cos(an));
    float3 ta = float3(0.0, 1.0, 0.0);

    // camera matrix
    float3 ww = normalize(ta - ro);
    float3 uu = normalize(cross(ww, float3(0.0, 1.0, 0.0)));
    float3 vv = normalize(cross(uu, ww));

    // create view ray
    float3 rd = normalize(p.x * uu + p.y * vv + 2.0 * ww);

    return DRay(ro, rd);
}

ここで注目するのは IDifferentiable というインターフェースです。
Slang ではユーザ定義型を自動微分のパラメータとして指定するために IDifferentiable インターフェースを利用します。
ユーザガイドを読むと色々と複雑な利用方法もあるようですが、ここでは微分したい処理の中で使う構造体には IDifferentiable を付ける、程度の理解で大丈夫です。
構造体のメンバ関数には [Differentiable] 属性を付けておきます。
また、微分に関わらないメンバ変数には no_diff キーワードを付けることもできます。

シーンオブジェクト

今回のシーンは平面と球体だけで構成されています。
レイトレのヒット情報と球を定義します。  

struct HitInfo : IDifferentiable
{
    float hitT;
    float3 position;
    float3 normal;
    float2 uv;
    float occlusion;
    int hitId;

    [Differentiable]
    __init()
    {
        hitT = -1.0;
        position = 0.0;
        normal = 0.0;
        uv = 0.0;
        occlusion = 1.0;
        hitId = -1;
    }
};
struct Sphere
{
    float3 center;
    float radius;

    __init(float3 _c, float _r)
    {
        center = _c;
        radius = _r;
    }

    [Differentiable]
    float intersect(DRay ray)
    {
        float t = -1.0;
        float3 ce = ray.origin - center;
        float b = dot(ray.direction, ce);
        float c = dot(ce, ce) - radius * radius;
        float h = b * b - c;
        if (h > 0.0)
        {
            t = -b - sqrt(h);
        }
        return t;
    }

    [Differentiable]
    HitInfo evalHitinfo(DRay ray, float t, int id)
    {
        float3 pos = ray.evalPos(t);
        HitInfo hitinfo;
        hitinfo.hitT = t;
        hitinfo.position = pos;
        hitinfo.normal = normalize(pos - center);
        hitinfo.uv = evalUV(pos);
        hitinfo.occlusion = 0.5 + 0.5 * hitinfo.normal.y;
        hitinfo.hitId = id;
        return hitinfo;
    }

    [Differentiable]
    float2 evalUV(float3 pos)
    {
        float3 q = normalize(pos - center);
        return float2(atan2(q.x, q.z), acos(q.y)) * radius * 12.0;
    }

    float evalOcclusion(HitInfo hitinfo)
    {
        float3 di = center - hitinfo.position;
        float l = length(di);
        return 1.0 - dot(hitinfo.normal, di / l) * radius * radius / (l * l);
    }

    float evalSoftShadow(DRay ray)
    {
        float3 oc = center - ray.origin;
        float b = dot(oc, ray.direction);

        float res = 1.0;
        if (b > 0.0)
        {
            float h = dot(oc, oc) - b * b - radius * radius;
            res = smoothstep(0.0, 1.0, 2.0 * h / b);
        }
        return res;
    }
};

HitInfo は微分する対象なので IDifferentiable を付け、Sphere は微分しない定数なので IDifferentiable を付けていません。
メンバ関数の必要な物には [Differentiable] を付けています。
このあたりは全体の設計が分かっていないと、どれが微分に必要なのか把握できないため、実際の実装では適宜エラー文を確認しつつ修飾を記述していく必要がありました。

Occlusion や SoftShadow の実装も Slang の言語機能を(無理に?)活用したかったので、Sphere のメンバとして実装しています。
オリジナルの Shadertoy 実装と同じ結果になるように本質的な計算は変えてませんが、非効率になっている部分もあるかもしれません。
自動微分に限らず Slang の生成するコードがどれだけ効率的に実行できているのか、どのように記述するとパフォーマンス的に優秀なのか、そんな検証も機会があればやっていきたいです。

レイトレの実装

オブジェクトとレイの交差処理を実装します。
平面と配列で持った Sphere の交差判定を行い、HitInfo にヒットした位置の情報を格納します。

[Differentiable]
float2 texcoord_plane(float3 pos)
{
    return pos.xz * 12.0;
}
[Differentiable]
float intersect_plane(DRay ray, float height)
{
    return (height - ray.origin.y) / ray.direction.y;
}
[Differentiable]
HitInfo intersect<let N : int>(float2 pixPos, no_diff float time, no_diff Array<Sphere, N> spheres)
{
    float t_min = 1e6;
    DRay ray = computeDRay(pixPos, time);
    HitInfo hitinfo;

    // raytrace-plane
    float h = intersect_plane(ray, 0.01);
    if (h > 0.0)
    {
        t_min = hitinfo.hitT = h;
        hitinfo.position = ray.evalPos(h);
        hitinfo.normal = float3(0.0, 1.0, 0.0);
        hitinfo.uv = texcoord_plane(hitinfo.position);
        hitinfo.occlusion = 1;
        hitinfo.hitId = -1;
    }

    // raytrace-sphere
    [ForceUnroll]
    for (int i = 0; i < N; i++)
    {
        h = spheres[i].intersect(ray);
        if (h > 0.0 && h < t_min)
        {
            t_min = h;
            hitinfo = spheres[i].evalHitinfo(ray, h, i);
        }
    }

    return hitinfo;
}

描画する球の数を簡単に変えられるように配列にしています。
一般的なシェーダと同様に、Slang でも動的な配列は扱えないので固定サイズですが、ジェネリクスの値パラメータ <let N : int> を使って少しスマートに記述してみました。
C++ における非型テンプレートパラメータ (non-type template parameter) のような機能です。  

試した限りでは unroll できない for 文は自動微分の対象にできないようで、配列サイズの for 文を unroll するために値パラメータを採用しました。
値パラメータを使わないのであれば最大の要素数を適当に制限して次のようにも書けますが、ちょっと非効率になりそうです。

[Differentiable]
HitInfo intersect(float2 pixPos, no_diff float time, no_diff Sphere spheres[])
{
    ...
    [MaxIters(10)]
    for (int i = 0; i < spheres.getCount(); i++)
    { ... }
    ...
}

チェッカーボードパターン

チェッカーボードパターンの実装は、HLSL と GLSL の違い以外にはオリジナルから変更していません。
微分を使った解析的なフィルタリングの詳しい仕組みは元記事をご参照ください。

チェッカーボードの解析的なフィルタリングということで、実践的なテクスチャのサンプリングには直接使えないようにも思いますが、ミップマップをシミュレートした概念としてはわかりやすい例ではあると思います。

// --- analytically box-filtered checkerboard ---
float checkersTextureGradBox(float2 p, float2 ddx, float2 ddy)
{
    // filter kernel
    float2 w = max(abs(ddx), abs(ddy)) + 0.01;
    // analytical integral (box filter)
    float2 i = 2.0 * (abs(frac((p - 0.5 * w) / 2.0) - 0.5) - abs(frac((p + 0.5 * w) / 2.0) - 0.5)) / w;
    // xor pattern
    return saturate(0.5 - 0.5 * i.x * i.y);
}

// --- unfiltered checkerboard ---
float glsl_mod(float x, float y)
{
    return x - y * floor(x / y);
}
float checkersTexture(float2 p)
{
    float2 q = floor(p);
    return saturate(glsl_mod(q.x + q.y, 2.0)); // xor pattern
}

本体の実装

残るライティングとエントリポイント(raydifferential)の実装を行います。

今回は画面を3分割して結果を比較できるようにしました。
左から、元実装のフィルタリング(checkersTextureGradBox)、フィルタリングなし(checkersTexture)、自動微分を使ったフィルタリング(checkersTextureGradBox)としています。

float3 doLighting<let N : int>(HitInfo hitinfo, Array<Sphere, N> spheres)
{
    DRay lightingRay = DRay(hitinfo.position, float3(0.57703, 0.57703, 0.57703));

    float sh = 1.0;
    float oc = hitinfo.occlusion;
    [ForceUnroll]
    for (int i = 0; i < N; i++)
    {
        sh = min(sh, spheres[i].evalSoftShadow(lightingRay));
        oc *= spheres[i].evalOcclusion(hitinfo);
    }

    float dif = clamp(dot(hitinfo.normal, lightingRay.direction), 0.0, 1.0);
    float bac = clamp(0.5 + 0.5 * dot(hitinfo.normal, float3(-0.707, 0.0, -0.707)), 0.0, 1.0);
    float3 lin = dif * float3(1.50, 1.40, 1.30) * sh;
    lin += oc * float3(0.15, 0.20, 0.30);
    lin += bac * float3(0.10, 0.10, 0.10) * (0.2 + 0.8 * oc);

    return lin;
}

RWTexture2D<float4> result;
uniform float4 _TexelSize; // 1/width, 1/height, width, height
uniform float _Time;

[shader("compute")]
[numthreads(32, 32, 1)]
void raydifferential(uint3 threadId: SV_DispatchThreadID)
{
    float time = 0; //32.0 + _Time * 1.5;
    float2 pixPos = (2 * (threadId.xy + 0.5) - _TexelSize.zw) * _TexelSize.y;
    float screenWd = _TexelSize.z * _TexelSize.y;

    static const int NUM_SPHERES = 6;
    Array<Sphere, NUM_SPHERES> spheres = 
    {
        Sphere(float3(-1.0, 0.2,  3.0), 0.2),
        Sphere(float3( 1.0, 0.2,  3.0), 0.2),
        Sphere(float3( 0.0, 0.2,  3.0), 0.2),
        Sphere(float3(-5.0, 2.0, -4.0), 2.0),
        Sphere(float3( 5.0, 2.0, -4.0), 2.0),
        Sphere(float3( 0.0, 2.0, -4.0), 2.0),
    };

    // trace
    HitInfo hitinfo = intersect<NUM_SPHERES>(pixPos, time, spheres);

    float3 col = pow(0.1, 2.2); // BG color
    if (hitinfo.hitT > 0.0)
    {
        // shading
        float3 material = 0.0;

        // left: original ver.
        if (pixPos.x < -screenWd / 3.0)
        {
            float2 px = (2 * (threadId.xy + 0.5 + float2(1.0, 0.0)) - _TexelSize.zw) * _TexelSize.y;
            float2 py = (2 * (threadId.xy + 0.5 + float2(0.0, 1.0)) - _TexelSize.zw) * _TexelSize.y;
            DRay dray_dx = computeDRay(px, time);
            DRay dray_dy = computeDRay(py, time);

            // computer ray differentials
            float3 ddx_pos = dray_dx.origin - dray_dx.direction * dot(dray_dx.origin - hitinfo.position, hitinfo.normal) / dot(dray_dx.direction, hitinfo.normal);
            float3 ddy_pos = dray_dy.origin - dray_dy.direction * dot(dray_dy.origin - hitinfo.position, hitinfo.normal) / dot(dray_dy.direction, hitinfo.normal);

            // calc texture sampling footprint
            float2 ddx_uv, ddy_uv;
            if (hitinfo.hitId == -1) // plane
            {
                ddx_uv = texcoord_plane(ddx_pos) - hitinfo.uv;
                ddy_uv = texcoord_plane(ddy_pos) - hitinfo.uv;
            }
            else // sphere
            {
                ddx_uv = texcoord_sphere(ddx_pos, spheres[hitinfo.hitId]) - hitinfo.uv;
                ddy_uv = texcoord_sphere(ddy_pos, spheres[hitinfo.hitId]) - hitinfo.uv;
            }

            material = checkersTextureGradBox(hitinfo.uv, ddx_uv, ddy_uv);
        }
        // right: auto-diff ver.
        else if (pixPos.x > screenWd / 3.0)
        {
            DifferentialPair<float2> dpdx = diffPair(pixPos, float2(1, 0));
            DifferentialPair<float2> dpdy = diffPair(pixPos, float2(0, 1));

            DifferentialPair<HitInfo> duv_dpx = fwd_diff(intersect<NUM_SPHERES>)(dpdx, time, spheres);
            DifferentialPair<HitInfo> duv_dpy = fwd_diff(intersect<NUM_SPHERES>)(dpdy, time, spheres);

            float2 ddx_uv = duv_dpx.d.uv * (2 * _TexelSize.yy);
            float2 ddy_uv = duv_dpy.d.uv * (2 * _TexelSize.yy);

            material = checkersTextureGradBox(hitinfo.uv, ddx_uv, ddy_uv);
        }
        // center: no filtering
        else material = checkersTexture(hitinfo.uv);

        // combine lighting with material
        col = material * doLighting<NUM_SPHERES>(hitinfo, spheres);
    }

    // gamma correction
    col = pow(col, float3(0.4545));

    // border lines
    col *= smoothstep(1.0, 2.0, abs(pixPos.x + (screenWd / 3.0)) / (2.0 * _TexelSize.y));
    col *= smoothstep(1.0, 2.0, abs(pixPos.x - (screenWd / 3.0)) / (2.0 * _TexelSize.y));

    result[threadId.xy] = float4(col, 1.0);
    return;
}

// right: auto-diff ver. で示しているスコープが自動微分を使った本体の実装になります。

今回の例は、入力がピクセル座標で出力が HitInfo (必要なのは UV だけ)であり、出力の次数の方が大きいので forward-mode の自動微分 fwd_diff を使います。

今スクリーン座標の X 方向と Y 方向それぞれについて微分した UV 座標が欲しいので、diffPair でそれぞれの方向を指定した dpdx, dpdy を宣言して intersect 関数の微分系 fwd_diff(intersect<NUM_SPHERES>) の引数として指定します。
reverse-mode では diffPair は入力パラメータだけで宣言しましたが、forward-mode では入力パラメータとそれにかかる係数で宣言します。
今回はこの係数がスクリーンの微分したい方向を表しているため理解しやすいとは思います。

forward-mode では返り値が元の関数の出力を微分したものを表すので、 DifferentialPair<HitInfo> duv_dpxduv_dpy を受け取るように実装します。

HitInfo すべてを微分していますが、ここでは UV である duv_dpx.d.uv, duv_dpy.d.uv だけ使います。
(ヒット位置や法線の微分もなにかに活用できるかもしれません。)

また、微分された値についてオリジナルの実装で float2 px = (2 * (threadId.xy + 0.5 + float2(1.0, 0.0)) - _TexelSize.zw) * _TexelSize.y; のようにスクリーンのピクセルサイズで差分を取っているので、それに合わせて float2 ddx_uv = duv_dpx.d.uv * (2 * _TexelSize.yy); のようにスケールしています。
diffPair 宣言時の係数に入れてしまっても良いかもしれません。

結果

最終的な結果を示します。

左から、オリジナルのフィルタリング、フィルタなし、自動微分を使ったフィルタリングです。
左と右で同等の結果が得られ、真ん中と比べてアンチエイリアシングの効果があることがわかります。
自動微分を使って UV の変化量に応じたフィルタリングを実装できることが確認できました。

まとめ

本記事では Slang を使った自動微分シェーダ開発について二つの例を紹介しました。
私自身 Slang の構文や自動微分についてまだ初学者なので、今回の内容に留まらず勉強を進めて Slang の活用を積極的に発信していきたいと考えています。

特に Slang を使った微分可能レンダリングや、発表されたばかりの Neural Shading には興味があるので、色々検証してみたいです。

自動微分の選び方

今回の実装に際して、forward-mode と reverse-mode がある自動微分の使い方は、とても理解が難しいと感じました。
正直まだどのようなケースでどちらをどう使うべきなのか、簡潔には説明できないです。

reverse-mode では、入力パラメータすべてについて微分した出力が得られ必要な部分を利用します。
SDF の例では、法線としてスカラー場 $f$ の勾配ベクトル $(\frac{\partial f}{\partial x}, \frac{\partial f}{\partial y}, \frac{\partial f}{\partial z})$ が欲しかったので、入力の3次元座標 $(x,y,z)$ についてまとめて微分する reverse-mode を使いました。

forward-mode では、必要な入力を指定して、そのパラメータについてだけの微分が得られます。
フィルタリングの例は、ピクセル座標 $p=(p_x,p_y)$ の変化に対する UV 座標 $(u,v)$ の変化量がピクセル座標の方向毎 $(\frac{\partial u}{\partial p_x},\frac{\partial v}{\partial p_x}),(\frac{\partial u}{\partial p_y},\frac{\partial v}{\partial p_y})$ に欲しかったので forward-mode を使いました。

このあたりの実装は一朝一夕では身につかず、自動微分のプログラミングに慣れていくことで上手に利用できるようになるしかないのだろうと思います。

おまけ

フィルタリングの例で参考にさせていただいた Inigo Quilez 氏の Shadertoy 実装から、有名な SDF のサンプルも Slang に移植してみました。

SDF の法線と床のチェッカーボードパターンには今回の自動微分実装を利用しています。
結構長いですが実装全文も載せておきます。

こちらも Slang の構文を活用しようとしすぎて、逆にわかりにくくなってしまっている部分が多々あるかもしれません。
また、今回はどのサンプルも1ファイルで完結させていますが、実践的には Slang のモジュール機能を使って共通の処理を切り分けて複数のファイルに分割すべきでしょう。

RWTexture2D<float4> result;
uniform float4 _TexelSize; // 1/width, 1/height, width, height
uniform float _Time;

// -----------------------------------------------------------------------------------------------
// SDF utils -------------------------------------------------------------------------------------
interface ISDF
{
    float dist(float3 p);
};

[Differentiable]
float getDistance<T : ISDF>(float3 p, no_diff T sdf)
{
    return sdf.dist(p);
}

float3 getNormal<T : ISDF>(float3 p, T sdf) {
    float2 d = float2(0, 1e-4);
    return normalize(float3(sdf.dist(p + d.yxx) - sdf.dist(p - d.yxx),
                            sdf.dist(p + d.xyx) - sdf.dist(p - d.xyx),
                            sdf.dist(p + d.xxy) - sdf.dist(p - d.xxy)));
}

float3 getNormalAD<T : ISDF>(float3 p, T sdf)
{
    DifferentialPair<float3> diffPos = diffPair(p);
    bwd_diff(getDistance)(diffPos, sdf, 1.0);
    return normalize(diffPos.d.xyz);
}
// -----------------------------------------------------------------------------------------------

// -----------------------------------------------------------------------------------------------
// iq's SDFs -------------------------------------------------------------------------------------
// https://www.shadertoy.com/view/Xds3zN

[Differentiable]
float dot2(float2 v) { return dot(v, v); }

[Differentiable]
float dot2(float3 v) { return dot(v, v); }

[Differentiable]
float ndot(float2 a, float2 b) { return a.x * b.x - a.y * b.y; }
struct Sphere : ISDF
{
    no_diff float3 center;
    no_diff float radius;

    __init(float3 _center, float r)
    {
        center = _center;
        radius = r;
    }

    [Differentiable]
    float dist(float3 p)
    {
        return length(p - center) - radius;
    }
};

struct Plane : ISDF
{
    no_diff float height;

    __init(float h)
    {
        height = h;
    }

    [Differentiable]
    float dist(float3 p)
    {
        return p.y - height;
    }
};

struct Box : ISDF
{
    no_diff float3 center;
    no_diff float3 b;

    __init(float3 _center, float3 _b)
    {
        center = _center;
        b = _b;
    }

    [Differentiable]
    float dist(float3 p)
    {
        p = p - center;
        float3 d = abs(p) - b;
        return min(max(d.x, max(d.y, d.z)), 0.0) + length(max(d, 0.0));
    }
};

struct BoxFrame : ISDF
{
    no_diff float3 center;
    no_diff float3 b;
    no_diff float e;

    __init(float3 _center, float3 _b, float _e)
    {
        center = _center;
        b = _b;
        e = _e;
    }

    [Differentiable]
    float dist(float3 p)
    {
        p = p - center;
        p = abs(p) - b;
        float3 q = abs(p + e) - e;

        return min(min(length(max(float3(p.x, q.y, q.z), 0.0)) + min(max(p.x, max(q.y, q.z)), 0.0),
                       length(max(float3(q.x, p.y, q.z), 0.0)) + min(max(q.x, max(p.y, q.z)), 0.0)),
                       length(max(float3(q.x, q.y, p.z), 0.0)) + min(max(q.x, max(q.y, p.z)), 0.0));
    }
};

struct Ellipsoid : ISDF
{
    no_diff float3 center;
    no_diff float3 r;

    __init(float3 _center, float3 _r)
    {
        center = _center;
        r = _r;
    }

    [Differentiable]
    float dist(float3 p)
    {
        p = p - center;
        float k0 = length(p / r);
        float k1 = length(p / (r * r));
        return k0 * (k0 - 1.0) / k1;
    }
};

struct Torus : ISDF
{
    no_diff float3 center;
    no_diff float2 t;

    __init(float3 _center, float2 _t)
    {
        center = _center;
        t = _t;
    }

    [Differentiable]
    float dist(float3 p)
    {
        p = p - center;
        p = p.xzy; // TODO : transformation
        return length(float2(length(p.xz) - t.x, p.y)) - t.y;
    }
};

struct CappedTorus : ISDF
{
    no_diff float3 center;
    no_diff float2 sc;
    no_diff float ra;
    no_diff float rb;

    __init(float3 _center, float2 _sc, float _ra, float _rb)
    {
        center = _center;
        sc = _sc;
        ra = _ra;
        rb = _rb;
    }

    [Differentiable]
    float dist(float3 p)
    {
        p = p - center;
        p *= float3(1, -1, 1); // TODO : transformation
        p.x = abs(p.x);
        float k = (sc.y * p.x > sc.x * p.y) ? dot(p.xy, sc) : length(p.xy);
        return sqrt(dot(p, p) + ra * ra - 2.0 * ra * k) - rb;
    }
};

struct HexPrism : ISDF
{
    no_diff float3 center;
    no_diff float2 h;
    no_diff internal const float3 k = float3(-0.8660254, 0.5, 0.57735);

    __init(float3 _center, float2 _h)
    {
        center = _center;
        h = _h;
    }

    [Differentiable]
    float dist(float3 p)
    {
        p = p - center;
        float3 q = abs(p);

        p = abs(p);
        p.xy -= 2.0 * min(dot(k.xy, p.xy), 0.0) * k.xy;
        float2 d = float2(
       length(p.xy - float2(clamp(p.x, -k.z * h.x, k.z * h.x), h.x)) * sign(p.y - h.x), p.z - h.y);
        return min(max(d.x, d.y), 0.0) + length(max(d, 0.0));
    }
};

struct OctogonPrism : ISDF
{
    no_diff float3 center;
    no_diff float r;
    no_diff float h;
    no_diff internal const float3 k = float3(-0.9238795325,  // sqrt(2+sqrt(2))/2
                                              0.3826834323,  // sqrt(2-sqrt(2))/2
                                              0.4142135623); // sqrt(2)-1

    __init(float3 _center, float _r, float _h)
    {
        center = _center;
        r = _r;
        h = _h;
    }

    [Differentiable]
    float dist(float3 p)
    {
        p = p - center;
        // reflections
        p = abs(p);
        p.xy -= 2.0 * min(dot(float2(k.x, k.y), p.xy), 0.0) * float2(k.x, k.y);
        p.xy -= 2.0 * min(dot(float2(-k.x, k.y), p.xy), 0.0) * float2(-k.x, k.y);
        // polygon side
        p.xy -= float2(clamp(p.x, -k.z * r, k.z * r), r);
        float2 d = float2(length(p.xy) * sign(p.y), p.z - h);
        return min(max(d.x, d.y), 0.0) + length(max(d, 0.0));
    }
};

struct Capsule : ISDF
{
    no_diff float3 center;
    no_diff float3 a;
    no_diff float3 b;
    no_diff float r;

    __init(float3 _center, float3 _a, float3 _b, float _r)
    {
        center = _center;
        a = _a;
        b = _b;
        r = _r;
    }

    [Differentiable]
    float dist(float3 p)
    {
        p = p - center;
        float3 pa = p - a, ba = b - a;
        float h = clamp(dot(pa, ba) / dot(ba, ba), 0.0, 1.0);
        return length(pa - ba * h) - r;
    }
};

struct RoundConeA : ISDF
{
    no_diff float3 center;
    no_diff float r1;
    no_diff float r2;
    no_diff float h;

    __init(float3 _center, float _r1, float _r2, float _h)
    {
        center = _center;
        r1 = _r1;
        r2 = _r2;
        h = _h;
    }

    [Differentiable]
    float dist(float3 p)
    {
        p = p - center;
        float2 q = float2(length(p.xz), p.y);

        float b = (r1 - r2) / h;
        float a = sqrt(1.0 - b * b);
        float k = dot(q, float2(-b, a));

        if (k < 0.0) return length(q) - r1;
        if (k > a * h) return length(q - float2(0.0, h)) - r2;

        return dot(q, float2(a, b)) - r1;
    }
};

struct RoundConeB : ISDF
{
    no_diff float3 center;
    no_diff float3 a;
    no_diff float3 b;
    no_diff float r1;
    no_diff float r2;

    __init(float3 _center, float3 _a, float3 _b, float _r1, float _r2)
    {
        center = _center;
        a = _a;
        b = _b;
        r1 = _r1;
        r2 = _r2;
    }

    [Differentiable]
    float dist(float3 p)
    {
        p = p - center;
        // sampling independent computations (only depend on shape)
        float3 ba = b - a;
        float l2 = dot(ba, ba);
        float rr = r1 - r2;
        float a2 = l2 - rr * rr;
        float il2 = 1.0 / l2;

        // sampling dependant computations
        float3 pa = p - a;
        float y = dot(pa, ba);
        float z = y - l2;
        float x2 = dot2(pa * l2 - ba * y);
        float y2 = y * y * l2;
        float z2 = z * z * l2;

        // single square root!
        float k = sign(rr) * rr * rr * x2;
        if (sign(z) * a2 * z2 > k) return sqrt(x2 + z2) * il2 - r2;
        if (sign(y) * a2 * y2 < k) return sqrt(x2 + y2) * il2 - r1;
        return (sqrt(x2 * a2 * il2) + y * rr) * il2 - r1;
    }
};

struct TriPrism : ISDF
{
    no_diff float3 center;
    no_diff float2 h;

    __init(float3 _center, float2 _h)
    {
        center = _center;
        h = _h;
    }

    [Differentiable]
    float dist(float3 p)
    {
        p = p - center;
        const float k = sqrt(3.0);
        var hx = h.x * (0.5 * k);
        p.xy /= hx;
        p.x = abs(p.x) - 1.0;
        p.y = p.y + 1.0 / k;
        if (p.x + k * p.y > 0.0) p.xy = float2(p.x - k * p.y, -k * p.x - p.y) / 2.0;
        p.x -= clamp(p.x, -2.0, 0.0);
        float d1 = length(p.xy) * sign(-p.y) * hx;
        float d2 = abs(p.z) - h.y;
        return length(max(float2(d1, d2), 0.0)) + min(max(d1, d2), 0.);
    }
};

struct CylinderV : ISDF
{
    no_diff float3 center;
    no_diff float2 h;

    __init(float3 _center, float2 _h)
    {
        center = _center;
        h = _h;
    }

    [Differentiable]
    float dist(float3 p)
    {
        p = p - center;
        float2 d = abs(float2(length(p.xz), p.y)) - h;
        return min(max(d.x, d.y), 0.0) + length(max(d, 0.0));
    }
};

struct Cylinder : ISDF
{
    no_diff float3 center;
    no_diff float3 a;
    no_diff float3 b;
    no_diff float r;

    __init(float3 _center, float3 _a, float3 _b, float _r)
    {
        center = _center;
        a = _a;
        b = _b;
        r = _r;
    }

    [Differentiable]
    float dist(float3 p)
    {
        p = p - center;
        float3 pa = p - a;
        float3 ba = b - a;
        float baba = dot(ba, ba);
        float paba = dot(pa, ba);

        float x = length(pa * baba - ba * paba) - r * baba;
        float y = abs(paba - baba * 0.5) - baba * 0.5;
        float x2 = x * x;
        float y2 = y * y * baba;
        float d = (max(x, y) < 0.0) ? -min(x2, y2) : (((x > 0.0) ? x2 : 0.0) + ((y > 0.0) ? y2 : 0.0));
        return sign(d) * sqrt(abs(d)) / baba;
    }
};

struct Cone : ISDF
{
    no_diff float3 center;
    no_diff float2 c;
    no_diff float h;

    __init(float3 _center, float2 _c, float _h)
    {
        center = _center;
        c = _c;
        h = _h;
    }

    [Differentiable]
    float dist(float3 p)
    {
        p = p - center;
        float2 q = h * float2(c.x, -c.y) / c.y;
        float2 w = float2(length(p.xz), p.y);

        float2 a = w - q * clamp(dot(w, q) / dot(q, q), 0.0, 1.0);
        float2 b = w - q * float2(clamp(w.x / q.x, 0.0, 1.0), 1.0);
        float k = sign(q.y);
        float d = min(dot(a, a), dot(b, b));
        float s = max(k * (w.x * q.y - w.y * q.x), k * (w.y - q.y));
        return sqrt(d) * sign(s);
    }
};

struct CappedConeA : ISDF
{
    no_diff float3 center;
    no_diff float h;
    no_diff float r1;
    no_diff float r2;

    __init(float3 _center, float _h, float _r1, float _r2)
    {
        center = _center;
        h = _h;
        r1 = _r1;
        r2 = _r2;
    }

    [Differentiable]
    float dist(float3 p)
    {
        p = p - center;
        float2 q = float2(length(p.xz), p.y);

        float2 k1 = float2(r2, h);
        float2 k2 = float2(r2 - r1, 2.0 * h);
        float2 ca = float2(q.x - min(q.x, (q.y < 0.0) ? r1 : r2), abs(q.y) - h);
        float2 cb = q - k1 + k2 * clamp(dot(k1 - q, k2) / dot2(k2), 0.0, 1.0);
        float s = (cb.x < 0.0 && ca.y < 0.0) ? -1.0 : 1.0;
        return s * sqrt(min(dot2(ca), dot2(cb)));
    }
};

struct CappedConeB : ISDF
{
    no_diff float3 center;
    no_diff float3 a;
    no_diff float3 b;
    no_diff float ra;
    no_diff float rb;

    __init(float3 _center, float3 _a, float3 _b, float _ra, float _rb)
    {
        center = _center;
        a = _a;
        b = _b;
        ra = _ra;
        rb = _rb;
    }

    [Differentiable]
    float dist(float3 p)
    {
        p = p - center;
        float rba = rb - ra;
        float baba = dot(b - a, b - a);
        float papa = dot(p - a, p - a);
        float paba = dot(p - a, b - a) / baba;

        float x = sqrt(papa - paba * paba * baba);

        float cax = max(0.0, x - ((paba < 0.5) ? ra : rb));
        float cay = abs(paba - 0.5) - 0.5;

        float k = rba * rba + baba;
        float f = clamp((rba * (x - ra) + paba * baba) / k, 0.0, 1.0);

        float cbx = x - ra - f * rba;
        float cby = paba - f;

        float s = (cbx < 0.0 && cay < 0.0) ? -1.0 : 1.0;

        return s * sqrt(min(cax * cax + cay * cay * baba,
                            cbx * cbx + cby * cby * baba));
    }
};

struct SolidAngle : ISDF
{
    no_diff float3 center;
    no_diff float2 c;
    no_diff float ra;

    __init(float3 _center, float2 _c, float _ra)
    {
        center = _center;
        c = _c;
        ra = _ra;
    }

    // c is the sin/cos of the desired cone angle
    [Differentiable]
    float dist(float3 p)
    {
        p = p - center;
        float2 pop = float2(length(p.xz), p.y);
        float l = length(pop) - ra;
        float m = length(pop - c * clamp(dot(pop, c), 0.0, ra));
        return max(l, m * sign(c.y * pop.x - c.x * pop.y));
    }
};

struct Octahedron : ISDF
{
    no_diff float3 center;
    no_diff float s;

    __init(float3 _center, float _s)
    {
        center = _center;
        s = _s;
    }

    [Differentiable]
    float dist(float3 p)
    {
        p = p - center;
        p = abs(p);
        float m = p.x + p.y + p.z - s;

    // exact distance
    #if 0
        float3 o = min(3.0*p - m, 0.0);
        o = max(6.0*p - m*2.0 - o*3.0 + (o.x+o.y+o.z), 0.0);
        return length(p - s*o/(o.x+o.y+o.z));
    #endif

    // exact distance
    #if 1
        float3 q;
        if (3.0 * p.x < m) q = p.xyz;
        else if (3.0 * p.y < m) q = p.yzx;
        else if (3.0 * p.z < m) q = p.zxy;
        else return m * 0.57735027;
        float k = clamp(0.5 * (q.z - q.y + s), 0.0, s);
        return length(float3(q.x, q.y - s + k, q.z - k));
    #endif

    // bound, not exact
    #if 0
	    return m*0.57735027;
    #endif
    }
};

struct Pyramid : ISDF
{
    no_diff float3 center;
    no_diff float h;

    __init(float3 _center, float _h)
    {
        center = _center;
        h = _h;
    }

    [Differentiable]
    float dist(float3 p)
    {
        p = p - center;
        float m2 = h * h + 0.25;

        // symmetry
        p.xz = abs(p.xz);
        p.xz = (p.z > p.x) ? p.zx : p.xz;
        p.xz -= 0.5;

        // project into face plane (2D)
        float3 q = float3(p.z, h * p.y - 0.5 * p.x, h * p.x + 0.5 * p.y);

        float s = max(-q.x, 0.0);
        float t = clamp((q.y - 0.5 * p.z) / (m2 + 0.25), 0.0, 1.0);

        float a = m2 * (q.x + s) * (q.x + s) + q.y * q.y;
        float b = m2 * (q.x + 0.5 * t) * (q.x + 0.5 * t) + (q.y - m2 * t) * (q.y - m2 * t);

        float d2 = min(q.y, -q.x * m2 - q.y * 0.5) > 0.0 ? 0.0 : min(a, b);

        // recover 3D and scale, and add sign
        return sqrt((d2 + q.z * q.z) / m2) * sign(max(q.z, -p.y));
    }
};

struct Rhombus : ISDF
{
    no_diff float3 center;
    no_diff float la;
    no_diff float lb;
    no_diff float h;
    no_diff float ra;

    __init(float3 _center, float _la, float _lb, float _h, float _ra)
    {
        center = _center;
        la = _la;
        lb = _lb;
        h = _h;
        ra = _ra;
    }

    // la,lb=semi axis, h=height, ra=corner
    [Differentiable]
    float dist(float3 p)
    {
        p = p - center;
        p = p.xzy; // TODO : transformation
        p = abs(p);
        float2 b = float2(la, lb);
        float f = clamp((ndot(b, b - 2.0 * p.xz)) / dot(b, b), -1.0, 1.0);
        float2 q = float2( length(p.xz - 0.5 * b * float2(1.0 - f, 1.0 + f)) * sign(p.x * b.y + p.z * b.x - b.x * b.y) - ra, p.y - h);
        return min(max(q.x, q.y), 0.0) + length(max(q, 0.0));
    }
};

struct Horseshoe : ISDF
{
    no_diff float3 center;
    no_diff float2 c;
    no_diff float r;
    no_diff float le;
    no_diff float2 w;

    __init(float3 _center, float2 _c, float _r, float _le, float2 _w)
    {
        center = _center;
        c = _c;
        r = _r;
        le = _le;
        w = _w;
    }

    [Differentiable]
    float dist(float3 p)
    {
        p = p - center;
        p.x = abs(p.x);
        float l = length(p.xy);
        p.xy = mul(float2x2(-c.x, c.y, c.y, c.x), p.xy);
        p.xy = float2((p.y > 0.0 || p.x > 0.0) ? p.x : l * sign(-c.x),
                (p.x > 0.0) ? p.y : l);
        p.xy = float2(p.x, abs(p.y - r)) - float2(le, 0.0);

        float2 q = float2(length(max(p.xy, 0.0)) + min(0.0, max(p.x, p.y)), p.z);
        float2 d = abs(q) - w;
        return min(max(d.x, d.y), 0.0) + length(max(d, 0.0));
    }
};
// -----------------------------------------------------------------------------------------------

// -----------------------------------------------------------------------------------------------
// differentiable ray ----------------------------------------------------------------------------
[Differentiable]
float3x3 setCamera(float3 ro, float3 ta, float cr)
{
    float3 cw = normalize(ta - ro);
    float3 cp = float3(sin(cr), cos(cr), 0.0);
    float3 cu = normalize(cross(cw, cp));
    float3 cv = (cross(cu, cw));
    return float3x3(cu, cv, cw);
}

struct DRay : IDifferentiable
{
    typealias Differential = DRay;

    float3 origin;
    float3 direction;

    [Differentiable]
    __init(float3 o, float3 d)
    {
        origin = o;
        direction = d;
    }

    float3 evalPos(float t)
    {
        return origin + t * direction;
    }
};

[Differentiable]
DRay computeDRay(float2 p, no_diff float time)
{
    float2 mo = float2(0.0, 0.0);

    // camera
    float3 ta = float3(0.25, -0.75, -0.75);
    float3 ro = ta + float3(4.5 * cos(0.1 * time + 7.0 * mo.x), 2.2, 4.5 * sin(0.1 * time + 7.0 * mo.x));

    // camera-to-world transformation
    float3x3 ca = setCamera(ro, ta, 0.0);

    // focal length
    const float fl = 2.5;

    // ray direction
    float3 rd = mul(normalize(float3(p, fl)), ca);

    return DRay(ro, rd);
}

[Differentiable]
float3 planeProjection(float2 p, no_diff float time)
{
    DRay ray = computeDRay(p, time);
    return ray.direction / ray.direction.y;
}
// -----------------------------------------------------------------------------------------------

// -----------------------------------------------------------------------------------------------
// scene utilities -------------------------------------------------------------------------------
struct HitObject
{
    float distance;
    float3 normal;
    float4 color;

    __init(float _d, float3 _n, float4 _c)
    {
        distance = _d;
        normal = _n;
        color = _c;
    }

    [mutating] [Differentiable]
    HitObject opU(HitObject a)
    {
        if (a.distance < this.distance)
            this = a;
        return this;
    }
};

struct SceneObject<T : ISDF>
{
    T sdf;
    float material;

    __init(T _s, float _m)
    {
        sdf = _s;
        material = _m;
    }

    float4 evalColor()
    {
        float3 col = 0.2 + 0.2 * sin(material * 2.0 + float3(0.0, 1.0, 2.0));
        float ks = 1.0;
        return float4(col, ks);
    }

    HitObject evalHit(float3 p)
    {
        return HitObject(getDistance(p, sdf), getNormalAD(p, sdf), evalColor());
    }
};
// -----------------------------------------------------------------------------------------------

// -----------------------------------------------------------------------------------------------
// scene -----------------------------------------------------------------------------------------
// https://iquilezles.org/articles/checkerfiltering
float checkersGradBox(float2 p, float2 dpdx, float2 dpdy)
{
    // filter kernel
    float2 w = abs(dpdx) + abs(dpdy) + 0.001;
    // analytical integral (box filter)
    float2 i = 2.0 * (abs(fract((p - 0.5 * w) * 0.5) - 0.5) - abs(fract((p + 0.5 * w) * 0.5) - 0.5)) / w;
    // xor pattern
    return saturate(0.5 - 0.5 * i.x * i.y);
}
float glsl_mod(float x, float y)
{
    return x - y * floor(x / y);
}
float checkersTexture(float2 p)
{
    float2 q = floor(p);
    return saturate(glsl_mod(q.x + q.y, 2.0)); // xor pattern
}

static var bbA = Box(float3(-2.0, 0.3, 0.25), float3(0.3, 0.3, 1.0));
static SceneObject<Sphere>         sphere       = { Sphere(float3(-2.0, 0.25, 0.0), 0.25),                      26.9 };
static SceneObject<Rhombus>        rhombus      = { Rhombus(float3(-2.0, 0.25, 1.0), 0.15, 0.25, 0.04, 0.08),   17.0 };

static var bbB = Box(float3(0.0, 0.3, -1.0), float3(0.35, 0.3, 2.5));
static SceneObject<CappedTorus>    cappedTorus  = { CappedTorus(float3(0.0, 0.30, 1.0), float2(0.866025, -0.5), 0.25, 0.05),    25.0 };
static SceneObject<BoxFrame>       boxFrame     = { BoxFrame(float3(0.0, 0.25, 0.0), float3(0.3, 0.25, 0.2), 0.025),            16.9 };
static SceneObject<Cone>           cone         = { Cone(float3(0.0, 0.45, -1.0), float2(0.6, 0.8), 0.45),                      55.0 };
static SceneObject<CappedConeA>    cappedConeA  = { CappedConeA(float3(0.0, 0.25, -2.0), 0.25, 0.25, 0.1),                      13.67 };
static SceneObject<SolidAngle>     solidAngle   = { SolidAngle(float3(0.0, 0.00, -3.0), float2(3, 4) / 5.0, 0.4),               49.13 };

static var bbC = Box(float3(1.0, 0.3, -1.0), float3(0.35, 0.3, 2.5));
static SceneObject<Torus>          torus        = { Torus(float3(1.0, 0.30, 1.0), float2(0.25, 0.05)),                                      7.1 };
static SceneObject<Box>            box          = { Box(float3(1.0, 0.25, 0.0), float3(0.3, 0.25, 0.1)),                                    3.0 };
static SceneObject<Capsule>        capsule      = { Capsule(float3(1.0, 0.00, -1.0), float3(-0.1, 0.1, -0.1), float3(0.2, 0.4, 0.2), 0.1),  31.9 };
static SceneObject<CylinderV>      cylinderV    = { CylinderV(float3(1.0, 0.25, -2.0), float2(0.15, 0.25)),                                 8.0 };
static SceneObject<HexPrism>       hexPrism     = { HexPrism(float3(1.0, 0.2, -3.0), float2(0.2, 0.05)),                                    18.4 };

static var bbD = Box(float3(-1.0, 0.35, -1.0), float3(0.35, 0.35, 2.5));
static SceneObject<Pyramid>        pyramid      = { Pyramid(float3(-1.0, -0.6, -3.0), 1.0),                                                         13.56 };
static SceneObject<Octahedron>     octahedron   = { Octahedron(float3(-1.0, 0.15, -2.0), 0.35),                                                     23.56 };
static SceneObject<TriPrism>       triPrism     = { TriPrism(float3(-1.0, 0.15, -1.0), float2(0.3, 0.05)),                                          43.5 };
static SceneObject<Ellipsoid>      ellipsoid    = { Ellipsoid(float3(-1.0, 0.25, 0.0), float3(0.2, 0.25, 0.05)),                                    43.17 };
static SceneObject<Horseshoe>      horseshoe    = { Horseshoe(float3(-1.0, 0.25, 1.0), float2(cos(1.3), sin(1.3)), 0.2, 0.3, float2(0.03, 0.08)),   11.5 };

static var bbE = Box(float3(2.0, 0.3, -1.0), float3(0.35, 0.3, 2.5));
static SceneObject<OctogonPrism>   octogonPrism = { OctogonPrism(float3(2.0, 0.2, -3.0), 0.2, 0.05),                                                    51.8 };
static SceneObject<Cylinder>       cylinder     = { Cylinder(float3(2.0, 0.14, -2.0), float3(0.1, -0.1, 0.0), float3(-0.2, 0.35, 0.1), 0.08),           31.2 };
static SceneObject<CappedConeB>    cappedConeB  = { CappedConeB(float3(2.0, 0.09, -1.0), float3(0.1, 0.0, 0.0), float3(-0.2, 0.40, 0.1), 0.15, 0.05),   46.1 };
static SceneObject<RoundConeB>     roundConeB   = { RoundConeB(float3(2.0, 0.15, 0.0), float3(0.1, 0.0, 0.0), float3(-0.1, 0.35, 0.1), 0.15, 0.05),     51.7 };
static SceneObject<RoundConeA>     roundConeA   = { RoundConeA(float3(2.0, 0.20, 1.0), 0.2, 0.1, 0.3),                                                  37.0 };

HitObject map(float2 pixPos, float3 pos, float time)
{
#if 0 // no filtering
    float f = checkersTexture(3.0 * pos.xz);

#else // filtering
    DifferentialPair<float2> dpx = diffPair(pixPos, float2(1, 0));
    DifferentialPair<float2> dpy = diffPair(pixPos, float2(0, 1));

    DifferentialPair<float3> duv_dx = fwd_diff(planeProjection)(dpx, time);
    DifferentialPair<float3> duv_dy = fwd_diff(planeProjection)(dpy, time);

    float2 dpdx = duv_dx.d.xz * (2 * _TexelSize.yy);
    float2 dpdy = duv_dy.d.xz * (2 * _TexelSize.yy);

    float f = checkersGradBox(3.0 * pos.xz, 3.0 * dpdx, 3.0 * dpdy);

#endif

    float4 planeColor = 0.15 + f * 0.05;
    planeColor.w = 0.4; // ks

    HitObject h = HitObject(pos.y, float3(0, 1, 0), planeColor);

    if (getDistance(pos, bbA) < h.distance)
    {
        h.opU(sphere.evalHit(pos));
        h.opU(rhombus.evalHit(pos));
    }
    if (getDistance(pos, bbB) < h.distance)
    {
        h.opU(cappedTorus.evalHit(pos));
        h.opU(boxFrame.evalHit(pos));
        h.opU(cone.evalHit(pos));
        h.opU(cappedConeA.evalHit(pos));
        h.opU(solidAngle.evalHit(pos));
    }
    if (getDistance(pos, bbC) < h.distance)
    {
        h.opU(torus.evalHit(pos));
        h.opU(box.evalHit(pos));
        h.opU(capsule.evalHit(pos));
        h.opU(cylinderV.evalHit(pos));
        h.opU(hexPrism.evalHit(pos));
    }
    if (getDistance(pos, bbD) < h.distance)
    {
        h.opU(pyramid.evalHit(pos));
        h.opU(octahedron.evalHit(pos));
        h.opU(triPrism.evalHit(pos));
        h.opU(ellipsoid.evalHit(pos));
        h.opU(horseshoe.evalHit(pos));
    }
    if (getDistance(pos, bbE) < h.distance)
    {
        h.opU(octogonPrism.evalHit(pos));
        h.opU(cylinder.evalHit(pos));
        h.opU(cappedConeB.evalHit(pos));
        h.opU(roundConeB.evalHit(pos));
        h.opU(roundConeA.evalHit(pos));
    }

    return h;
}
// -----------------------------------------------------------------------------------------------

// -----------------------------------------------------------------------------------------------
// iq's methods ----------------------------------------------------------------------------------

// https://iquilezles.org/articles/rmshadows
float calcSoftshadow(float2 pixPos, DRay ray, float mint, float tmax, float time)
{
    // bounding volume
    float tp = (0.8 - ray.origin.y) / ray.direction.y; if (tp > 0.0) tmax = min(tmax, tp);

    float res = 1.0;
    float t = mint;
    for (int i = 0; i < 24; i++)
    {
        float h = map(pixPos, ray.evalPos(t), time).distance;
        float s = clamp(8.0 * h / t, 0.0, 1.0);
        res = min(res, s);
        t += clamp(h, 0.01, 0.2);
        if (res < 0.004 || t > tmax) break;
    }
    res = clamp(res, 0.0, 1.0);
    return res * res * (3.0 - 2.0 * res);
}

// https://iquilezles.org/articles/nvscene2008/rwwtt.pdf
float calcAO(float2 pixPos, DRay ray, float time)
{
    float occ = 0.0;
    float sca = 1.0;
    for (int i = 0; i < 5; i++)
    {
        float h = 0.01 + 0.12 * float(i) / 4.0;
        float d = map(pixPos, ray.evalPos(h), time).distance;
        occ += (h - d) * sca;
        sca *= 0.95;
        if (occ > 0.35) break;
    }
    return clamp(1.0 - 3.0 * occ, 0.0, 1.0) * (0.5 + 0.5 * ray.direction.y);
}
// -----------------------------------------------------------------------------------------------

// -----------------------------------------------------------------------------------------------
// main ------------------------------------------------------------------------------------------
static const int NUM_STEP = 100;

[shader("compute")]
[numthreads(32, 32, 1)]
void iq_scene(uint3 threadId: SV_DispatchThreadID)
{
    float time = 32.0 + _Time * 1.5;

    // pixel coordinates
    float2 pixPos = (2 * (threadId.xy + 0.5) - _TexelSize.zw) * _TexelSize.y;
    DRay ray = computeDRay(pixPos, time);

    float3 color = 0.0;
    HitObject res = HitObject(-1.0, 0.0, -1.0);

    float tmin = 1.0;
    float tmax = 20.0;

    // raytrace floor plane
    float tp1 = (0.0 - ray.origin.y) / ray.direction.y;
    if (tp1 > 0.0)
    {
        tmax = min(tmax, tp1);
        float4 bgColor = float4(float3(0.7, 0.7, 0.9) - max(ray.direction.y, 0.0) * 0.3, 1.0);
        res = HitObject(tp1, float3(0, 0, 0), bgColor);
    }

    // raymarch primitives
    float t = tmin;
    for (int i = 0; i < NUM_STEP && t < tmax; i++)
    {
        HitObject h = map(pixPos, ray.evalPos(t), time);

        if (abs(h.distance) < (0.0001 * t))
        {
            res = h;
            res.distance = t;
            break;
        }
        t += h.distance;
    }

    { // evaluate color
        float3 pos = ray.evalPos(res.distance);
        float3 nor = res.normal;
        float3 refl = reflect(ray.direction, nor);

        // material
        float3 col = res.color.rgb;
        float ks = res.color.w;

        // lighting
        DRay aoRay = DRay(pos, nor);
        float occ = calcAO(pixPos, aoRay, time);

        float3 lin = float3(0.0);
        // sun
        {
            DRay lightingRay = DRay(pos, normalize(float3(-0.5, 0.4, -0.6)));
            float3 hal = normalize(lightingRay.direction - ray.direction);
            float dif = clamp(dot(nor, lightingRay.direction), 0.0, 1.0);
                  dif *= calcSoftshadow(pixPos, lightingRay, 0.02, 2.5, time);
            float spe = pow(clamp(dot(nor, hal), 0.0, 1.0), 16.0);
                  spe *= dif;
                  spe *= 0.04 + 0.96 * pow(clamp(1.0 - dot(hal, lightingRay.direction), 0.0, 1.0), 5.0);
            lin += col * 2.20 * dif * float3(1.30, 1.00, 0.70);
            lin += 5.00 * spe * float3(1.30, 1.00, 0.70) * ks;
        }
        // sky
        {
            DRay skyRay = DRay(pos, refl);
            float dif = sqrt(clamp(0.5 + 0.5 * nor.y, 0.0, 1.0));
                  dif *= occ;
            float spe = smoothstep(-0.2, 0.2, refl.y);
                  spe *= dif;
                  spe *= 0.04 + 0.96 * pow(clamp(1.0 + dot(nor, ray.direction), 0.0, 1.0), 5.0);
                  spe *= calcSoftshadow(pixPos, skyRay, 0.02, 2.5, time);
            lin += col * 0.60 * dif * float3(0.40, 0.60, 1.15);
            lin += 2.00 * spe * float3(0.40, 0.60, 1.30) * ks;
        }
        // back
        {
            float dif = clamp(dot(nor, normalize(float3(0.5, 0.0, 0.6))), 0.0, 1.0) * clamp(1.0 - ray.origin.y, 0.0, 1.0);
                  dif *= occ;
            lin += col * 0.55 * dif * float3(0.25, 0.25, 0.25);
        }
        // sss
        {
            float dif = pow(clamp(1.0 + dot(nor, ray.direction), 0.0, 1.0), 2.0);
                  dif *= occ;
            lin += col * 0.25 * dif * float3(1.00, 1.00, 1.00);
        }

        color = lin;
        color = lerp(color, float3(0.7, 0.7, 0.9), 1.0 - exp(-0.0001 * pow(res.distance, 3.0)));
    }

    result[threadId.xy] = float4(color, 1.0);
}
// -----------------------------------------------------------------------------------------------
この記事をシェアする

コメントを残す

メールアドレスが公開されることはありません。 が付いている欄は必須項目です