Description
Missing rewrites:

- `A@B + A@C = A@(B+C)` (one fewer matmul)
- `s*A @ B = s*(A@B)` (which can be done by a single gemm routine)
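As a quick sanity check of the first identity, the sketch below uses plain Python lists (no PyTensor involved) and counts the scalar multiplications a naive matrix product performs. The helpers `matmul` and `matadd` are illustrative only, not part of any library.

```python
def matmul(X, Y):
    """Naive product of list-of-lists matrices; returns (result, scalar multiplications)."""
    n, k, m = len(X), len(Y), len(Y[0])
    out = [[sum(X[i][p] * Y[p][j] for p in range(k)) for j in range(m)] for i in range(n)]
    return out, n * k * m  # the naive algorithm does n*k*m multiplications


def matadd(X, Y):
    """Elementwise sum of two equally shaped matrices."""
    return [[x + y for x, y in zip(rx, ry)] for rx, ry in zip(X, Y)]


A = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # 3x2
B = [[1.0, 0.0, 2.0], [0.0, 1.0, 1.0]]    # 2x3
C = [[2.0, 1.0, 0.0], [1.0, 3.0, 1.0]]    # 2x3

# Left-hand side: two separate products, then an addition.
AB, cost_ab = matmul(A, B)
AC, cost_ac = matmul(A, C)
lhs = matadd(AB, AC)
lhs_cost = cost_ab + cost_ac

# Right-hand side: add first, then a single product.
rhs, rhs_cost = matmul(A, matadd(B, C))

print(lhs == rhs, lhs_cost, rhs_cost)  # True 36 18
```

The same saving applies when the shared operand is on the right, `B@A + C@A = (B+C)@A`.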
Neither rewrite is applied by the current `fast_run` rewrites:

```python
from pytensor.graph import rewrite_graph
import pytensor.tensor as pt

A, B, C = pt.matrices("ABC")
s = pt.scalar("s")

o1 = A @ B + A @ C
rewrite_graph(o1, include=("fast_run",), exclude=("inplace",)).dprint()
print()
o2 = (s * A) @ B
rewrite_graph(o2, include=("fast_run",), exclude=("inplace",)).dprint()
```

Output: `A@B + A@C` still needs two matrix products (a `Dot22` feeding the `Gemm`), and `s` is multiplied elementwise into `A` before the `Dot22`:

```
Gemm{no_inplace} [id A]
 ├─ Dot22 [id B]
 │  ├─ A [id C]
 │  └─ B [id D]
 ├─ 1.0 [id E]
 ├─ A [id C]
 ├─ C [id F]
 └─ 1.0 [id E]

Dot22 [id A]
 ├─ Mul [id B]
 │  ├─ ExpandDims{axes=[0, 1]} [id C]
 │  │  └─ s [id D]
 │  └─ A [id E]
 └─ B [id F]
```
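The printed `Gemm{no_inplace}` node takes its inputs in the order `(z, a, x, y, b)`; as described in PyTensor's `Gemm` docstring, it computes roughly `b*z + a*dot(x, y)`. The pure-Python reference below is a hypothetical helper that only mirrors that argument order (it is not PyTensor's implementation). It evaluates the first graph as printed, which still needs two products, and shows that `(s*A) @ B` fits a single gemm-style call with `a=s` and `b=0`.

```python
def dot(x, y):
    """Naive matrix product of list-of-lists matrices (stands in for Dot22)."""
    return [[sum(x[i][p] * y[p][j] for p in range(len(y))) for j in range(len(y[0]))]
            for i in range(len(x))]


def gemm(z, a, x, y, b):
    """Hypothetical reference for b*z + a*dot(x, y), same argument order as the printed Gemm node."""
    xy = dot(x, y)
    return [[b * zv + a * pv for zv, pv in zip(zr, pr)] for zr, pr in zip(z, xy)]


A = [[1.0, 2.0], [3.0, 4.0]]
B = [[0.0, 1.0], [1.0, 0.0]]
C = [[2.0, 0.0], [0.0, 2.0]]
s = 3.0

# Current graph for A@B + A@C: Gemm(Dot22(A, B), 1.0, A, C, 1.0), i.e. two matrix products.
current = gemm(dot(A, B), 1.0, A, C, 1.0)

# Desired form of (s*A) @ B: one gemm-style call with a=s; z is all zeros and b=0.0,
# so z contributes nothing to the result.
zeros = [[0.0] * len(B[0]) for _ in A]
fused = gemm(zeros, s, A, B, 0.0)

# What the current graph does instead: scale A elementwise, then a separate product.
unfused = dot([[s * v for v in row] for row in A], B)

print(current)           # [[4.0, 5.0], [10.0, 11.0]]
print(fused == unfused)  # True
```

With BLAS-style gemm, folding the scalar into the `alpha` argument avoids materializing the scaled copy of `A`.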