# Best Practices for Scholar Contributions

This document captures key patterns and requirements for implementing algorithms in Scholar.

## Core Principle: JIT/GPU Compatibility

**All algorithms must be JIT-compilable and GPU-compatible.**

Nx relies on multi-stage compilation. When `defn` is executed, it doesn't have
actual values, only tensor shapes and types. The `defn` execution then builds a
numeric graph, which is lowered just-in-time (JIT) to CPUs and GPUs.

## Required Patterns

### 1. Use `deftransform` for Entry Points

The main entry point must be a `deftransform` that simply unpacks
options and immediately calls a `defnp`:

```elixir
# GOOD - JIT compatible
deftransform fit(a, b, fun, opts \\ []) do
  opts = NimbleOptions.validate!(opts, @opts_schema)
  fit_n(a, b, fun, opts[:tol], opts[:maxiter])
end
```

Do not perform `Nx` operations inside `deftransform` (not even during validation).
Make sure all values that can be tensors are given as explicit arguments to the
`defnp` function, and invoke `Nx` operations there.
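For contrast, here is a hypothetical anti-pattern sketch: the `Nx.min/2` call below runs inside the `deftransform`, outside the compiled graph, which breaks JIT compatibility:

```elixir
# BAD - Nx operation inside deftransform (hypothetical sketch)
deftransform fit(a, b, fun, opts \\ []) do
  opts = NimbleOptions.validate!(opts, @opts_schema)
  # This Nx call happens outside the defn graph - move it into fit_n instead
  a = Nx.min(a, b)
  fit_n(a, b, fun, opts[:tol], opts[:maxiter])
end
```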

### 2. Expose Required Parameters as Function Arguments

Don't bury required parameters in options - expose them as explicit arguments:

```elixir
# GOOD - bounds as explicit args
deftransform fit(a, b, fun, opts \\ [])

# BAD - bounds buried in options
deftransform fit(fun, opts) do
  {a, b} = opts[:bracket]
```

**Why?** Arguments that cross the `deftransform -> defn` boundary are automatically converted to Nx tensors, which eliminates the need for custom validation logic.
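In practice this means callers can pass plain numbers or tensors interchangeably; both arrive inside the `defnp` as tensors (the `Brent` module name here is only an example):

```elixir
fun = fn x -> Nx.pow(Nx.subtract(x, 3), 2) end

# Equivalent calls: plain numbers are lifted to tensors at the defn boundary
Brent.fit(0.0, 5.0, fun)
Brent.fit(Nx.f64(0.0), Nx.f64(5.0), fun)
```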

### 3. Use `Nx.select` for Branch-Free Conditionals

For complex multi-way conditionals, use nested `Nx.select`:

```elixir
# Four-way conditional
result = Nx.select(
  cond1,
  value1,
  Nx.select(
    cond2,
    value2,
    Nx.select(cond3, value3, value4)
  )
)
```
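As a small self-contained illustration (not from the Scholar codebase), a three-way sign function can be expressed with two nested selects:

```elixir
# Returns 1 where x > 0, -1 where x < 0, and 0 otherwise
defnp sign(x) do
  Nx.select(x > 0, 1, Nx.select(x < 0, -1, 0))
end
```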

### 4. Use `while` Loop (Not Recursion)

```elixir
# GOOD - while loop
{final_state, _} =
  while {state = initial_state, {tol, maxiter}},
        state.iter < maxiter and state.b - state.a >= tol do
    # Update state
    new_state = %{state | iter: state.iter + 1, ...}
    {new_state, {tol, maxiter}}
  end

# BAD - recursive call
defnp loop(fun, state, tol, maxiter) do
  if converged?(state) do
    state
  else
    loop(fun, update(state), tol, maxiter)
  end
end
```

**Why?** Remember that when `defnp` runs, we don't have runtime values,
so most loops can't terminate within `defnp`. For the loops that can
terminate (where the condition is known statically), looping in `defnp`
means generating the graph recursively, which can become large and take
a long time to compile.

### 5. Never Use `Nx.to_number` in `defn`

All computation must stay as tensors for JIT compilation:

```elixir
# GOOD - all tensor operations
converged = state.b - state.a < tol

# BAD - converts to Elixir number (will crash inside JIT)
converged = Nx.to_number(state.b) - Nx.to_number(state.a) < Nx.to_number(tol)
```

### 6. Use Unsigned Types for Non-Negative Counters

```elixir
initial_state = %{
  iter: Nx.u32(0),     # u32 for iteration count
  f_evals: Nx.u32(2)   # u32 for function evaluation count
}
```

### 7. Let Users Control Precision via Input Types

Don't force type conversions - let the tensor type propagate from inputs:

```elixir
# GOOD - let user decide precision
defn minimize(a, b, fun, opts \\ []) do
  # a and b types propagate through computation
end

# BAD - forcing f64
a = Nx.tensor(a, type: :f64)
```

### 8. Use Module Constants Directly

```elixir
# GOOD - use @attr directly in defn
@phi 0.6180339887498949

defnp minimize_n(a, b, fun, tol, maxiter) do
  c = b - @phi * (b - a)
end

# BAD - wrapping in tensor
defnp minimize_n(a, b, fun, tol, maxiter) do
  phi = Nx.tensor(@phi)
  c = b - phi * (b - a)
end
```

### 9. Self-Contain Modules

Keep NimbleOptions validation in the same module - don't create wrapper modules:

```elixir
defmodule Scholar.Optimize.Brent do
  opts = [
    tol: [...],
    maxiter: [...]
  ]

  @opts_schema NimbleOptions.new!(opts)

  # Validation happens here, not in a separate module
end
```

## Test Requirements

Every module must include:

1. **Basic functionality tests** - Verify correct results on standard test functions
2. **Option handling tests** - Verify that options such as `:tol` and `:maxiter` change behavior as expected
3. **JIT compatibility test** - Critical! For example, if the algorithm has a `minimize` function, ensure it can be invoked when wrapped in `jit_apply/4`:

```elixir
test "works with jit_apply" do
  fun = fn x -> Nx.pow(Nx.subtract(x, 3), 2) end
  opts = [tol: 1.0e-5, maxiter: 500]
  result = Nx.Defn.jit_apply(&AlgorithmName.minimize/4, [0.0, 5.0, fun, opts])
  assert Nx.to_number(result.converged) == 1
end
```

4. **Tensor bounds test** - Verify that bounds can be given as both numbers and tensors
5. **Precision test** - Verify that f64 bounds yield higher precision
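The last two items might be sketched as follows (the `AlgorithmName` module and the `result.x` field are placeholders, mirroring the JIT test above):

```elixir
test "accepts tensor bounds" do
  fun = fn x -> Nx.pow(Nx.subtract(x, 3), 2) end
  result = AlgorithmName.minimize(Nx.tensor(0.0), Nx.tensor(5.0), fun)
  assert_in_delta Nx.to_number(result.x), 3.0, 1.0e-4
end

test "f64 bounds propagate to the result" do
  fun = fn x -> Nx.pow(Nx.subtract(x, 3), 2) end
  result = AlgorithmName.minimize(Nx.f64(0.0), Nx.f64(5.0), fun)
  assert Nx.type(result.x) == {:f, 64}
end
```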

## Validation Against SciPy

When implementing algorithms, validate results against SciPy:

```python
from scipy.optimize import minimize_scalar

result = minimize_scalar(func, bracket=(a, b), method='brent')
print(f"x: {result.x}, f(x): {result.fun}, iterations: {result.nit}")
```

Use these reference values in tests with appropriate tolerance.
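For example, `f(x) = (x - 3)**2` has its minimum at exactly `x = 3`, so SciPy's `minimize_scalar` reports `x ≈ 3.0` for it; a corresponding test (module name hypothetical) could be:

```elixir
test "matches SciPy reference value" do
  fun = fn x -> Nx.pow(Nx.subtract(x, 3), 2) end
  result = AlgorithmName.minimize(0.0, 5.0, fun)
  # SciPy reference: minimize_scalar(lambda x: (x - 3)**2) gives x ~= 3.0
  assert_in_delta Nx.to_number(result.x), 3.0, 1.0e-5
end
```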

## References

- PR #314: https://github.com/elixir-nx/scholar/pull/314 (RobustScaler)
- PR #327: https://github.com/elixir-nx/scholar/pull/327 (GoldenSection)