Commit 7c69c26: Clean up AGENTS.md a bit (1 parent 4c7a699)

4 files changed: +201 -303 lines changed

AGENTS.md

Lines changed: 197 additions & 0 deletions
# Best Practices for Scholar Contributions

This document captures key patterns and requirements for implementing algorithms in Scholar.
## Core Principle: JIT/GPU Compatibility

**All algorithms must be JIT-compilable and GPU-compatible.**

Nx relies on multi-stage compilation. When `defn` is executed, it doesn't have
actual values, only references to tensor shapes and types. The `defn` execution
then builds a numeric graph, which is lowered just-in-time (JIT) to CPUs and GPUs.
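To make the two stages concrete, here is a minimal sketch (only `Nx` is assumed; the module and function names are illustrative):

```elixir
defmodule Example do
  import Nx.Defn

  # At trace time `x` is a symbolic tensor expression: only its
  # shape and type are known, and each operation adds a graph node.
  defn scale_and_shift(x) do
    x * 2 + 1
  end
end

# The accumulated graph is lowered just-in-time for the current compiler:
Nx.Defn.jit_apply(&Example.scale_and_shift/1, [Nx.tensor([1, 2, 3])])
```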
## Required Patterns

### 1. Use `deftransform` for Entry Points

The main entry point must be a `deftransform` that simply unpacks
options and immediately calls a `defnp`.

```elixir
# GOOD - JIT compatible
deftransform fit(a, b, fun, opts \\ []) do
  opts = NimbleOptions.validate!(opts, @opts_schema)
  fit_n(a, b, fun, opts[:tol], opts[:maxiter])
end
```

Do not perform `Nx` operations inside `deftransform` (not even during validation).
Make sure all values that can be tensors are given as explicit arguments to the
`defnp` function, and invoke `Nx` operations there.
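As a counterexample (a hypothetical sketch, not taken from the codebase), performing eager `Nx` work inside the `deftransform` breaks under `jit`:

```elixir
# BAD - hypothetical sketch: eager Nx validation inside deftransform
deftransform fit(a, b, fun, opts \\ []) do
  opts = NimbleOptions.validate!(opts, @opts_schema)

  # Under jit, `a` and `b` are tensor expressions with no concrete
  # values, so forcing numbers here raises at trace time.
  if Nx.to_number(a) >= Nx.to_number(b) do
    raise ArgumentError, "expected a < b"
  end

  fit_n(a, b, fun, opts[:tol], opts[:maxiter])
end
```

Any such check belongs inside the `defnp` as tensor operations instead.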
### 2. Expose Required Parameters as Function Arguments

Don't bury required parameters in options - expose them as explicit arguments:

```elixir
# GOOD - bounds as explicit args
deftransform fit(a, b, fun, opts \\ [])

# BAD - bounds buried in options
deftransform fit(fun, opts) do
  {a, b} = opts[:bracket]
end
```

**Why?** Inputs are automatically converted to Nx tensors when they cross the `deftransform -> defn` boundary, eliminating the need for custom validation logic.
### 3. Use `Nx.select` for Branch-Free Conditionals

For complex multi-way conditionals, use nested `Nx.select`:

```elixir
# Four-way conditional
result =
  Nx.select(
    cond1,
    value1,
    Nx.select(
      cond2,
      value2,
      Nx.select(cond3, value3, value4)
    )
  )
```
### 4. Use `while` Loop (Not Recursion)

```elixir
# GOOD - while loop
{final_state, _} =
  while {state = initial_state, {tol, maxiter}},
        state.iter < maxiter and state.b - state.a >= tol do
    # Update state
    new_state = %{state | iter: state.iter + 1, ...}
    {new_state, {tol, maxiter}}
  end

# BAD - recursive call
defnp loop(fun, state, tol, maxiter) do
  if converged?(state) do
    state
  else
    loop(fun, update(state), tol, maxiter)
  end
end
```

**Why?** Remember that when `defnp` runs, we don't have runtime values,
so most loops can't terminate within `defnp`. And for the loops that can
terminate (where the condition is known statically), doing the loop in `defnp`
means that you are generating the graph recursively, which can then become
large and take a long time to compile.
### 5. Never Use `Nx.to_number` in `defn`

All computation must stay as tensors for JIT compilation:

```elixir
# GOOD - all tensor operations
converged = state.b - state.a < tol

# BAD - converts to Elixir number (will crash inside JIT)
converged = Nx.to_number(state.b) - Nx.to_number(state.a) < Nx.to_number(tol)
```
### 6. Use Unsigned Types for Non-Negative Counters

```elixir
initial_state = %{
  # u32 for iteration count
  iter: Nx.u32(0),
  # u32 for function evaluation count
  f_evals: Nx.u32(2)
}
```
### 7. Let Users Control Precision via Input Types

Don't force type conversions - let the tensor type propagate from inputs:

```elixir
# GOOD - let user decide precision
defn minimize(a, b, fun, opts \\ []) do
  # a and b types propagate through computation
end

# BAD - forcing f64
a = Nx.tensor(a, type: :f64)
```
### 8. Use Module Constants Directly

```elixir
# GOOD - use @attr directly in defn
@phi 0.6180339887498949

defnp minimize_n(a, b, fun, tol, maxiter) do
  c = b - @phi * (b - a)
end

# BAD - wrapping in tensor
defnp minimize_n(a, b, fun, tol, maxiter) do
  phi = Nx.tensor(@phi)
  c = b - phi * (b - a)
end
```
### 9. Self-Contain Modules

Keep NimbleOptions validation in the same module - don't create wrapper modules:

```elixir
defmodule Scholar.Optimize.Brent do
  opts = [
    tol: [...],
    maxiter: [...]
  ]

  @opts_schema NimbleOptions.new!(opts)

  # Validation happens here, not in a separate module
end
```
## Test Requirements

Every module must include:

1. **Basic functionality tests** - Verify correct results on standard test functions
2. **Option handling tests** - Verify that each supported option changes behavior as documented
3. **JIT compatibility test** - Critical! For example, if the algorithm has a `minimize` function, ensure it can be invoked through `Nx.Defn.jit_apply/2`:

   ```elixir
   test "works with jit_apply" do
     fun = fn x -> Nx.pow(Nx.subtract(x, 3), 2) end
     opts = [tol: 1.0e-5, maxiter: 500]
     result = Nx.Defn.jit_apply(&AlgorithmName.minimize/4, [0.0, 5.0, fun, opts])
     assert Nx.to_number(result.converged) == 1
   end
   ```

4. **Tensor bounds test** - Verify bounds are accepted both as numbers and as tensors
5. **Precision test** - Verify that f64 bounds yield higher-precision results
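For reference, the tensor bounds and precision tests might be sketched as follows (a hypothetical sketch: the module name `AlgorithmName`, the `minimize/4` signature, and the `x` result field follow the `jit_apply` example above and are assumptions):

```elixir
test "accepts both number and tensor bounds" do
  fun = fn x -> Nx.pow(Nx.subtract(x, 3), 2) end
  from_numbers = AlgorithmName.minimize(0.0, 5.0, fun, [])
  from_tensors = AlgorithmName.minimize(Nx.f32(0.0), Nx.f32(5.0), fun, [])
  # Both input forms should converge to the same minimum
  assert_in_delta Nx.to_number(from_numbers.x), Nx.to_number(from_tensors.x), 1.0e-6
end

test "f64 bounds yield higher precision" do
  fun = fn x -> Nx.pow(Nx.subtract(x, 3), 2) end
  result = AlgorithmName.minimize(Nx.f64(0.0), Nx.f64(5.0), fun, tol: 1.0e-10)
  # f64 inputs propagate, so a tighter tolerance is achievable
  assert_in_delta Nx.to_number(result.x), 3.0, 1.0e-8
end
```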
## Validation Against SciPy

When implementing algorithms, validate results against SciPy:

```python
from scipy.optimize import minimize_scalar

result = minimize_scalar(func, bracket=(a, b), method='brent')
print(f"x: {result.x}, f(x): {result.fun}, iterations: {result.nit}")
```

Use these reference values in tests with appropriate tolerance.
## References

- PR #314: https://github.com/elixir-nx/scholar/pull/314 (RobustScaler)
- PR #327: https://github.com/elixir-nx/scholar/pull/327 (GoldenSection)
