
Missing some simple matrix algebraic simplifications #1479

@ricardoV94

Description

The following simplifications are not applied by the `fast_run` rewrites:

  • `A@B + A@C = A@(B+C)` (one fewer matmul)
  • `(s*A) @ B = s*(A@B)` (which can be computed by a single GEMM call, since GEMM computes `alpha*(A@B) + beta*C` and the scalar can be passed as `alpha`)
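A minimal NumPy sanity check of both identities, using arbitrary random shapes chosen only for illustration (they are not from the issue). In floating point the two sides agree up to rounding, hence `np.allclose`. The hypothetical helper `matmul_flops` uses the usual estimate of about `2*n*k*m` operations for an `(n, k) @ (k, m)` product, to show roughly what the first rewrite saves.

```python
import numpy as np

rng = np.random.default_rng(0)
n, k, m = 4, 5, 3  # illustrative shapes (assumption, not from the issue)
A = rng.standard_normal((n, k))
B = rng.standard_normal((k, m))
C = rng.standard_normal((k, m))
s = 2.5

# Identity 1: factor out the shared left operand A.
lhs1 = A @ B + A @ C   # two matmuls plus an (n, m) add
rhs1 = A @ (B + C)     # one (k, m) add plus one matmul

# Identity 2: a scalar on one operand can be moved outside the product.
lhs2 = (s * A) @ B     # scale A elementwise, then multiply
rhs2 = s * (A @ B)     # multiply, then scale (foldable into GEMM's alpha)

def matmul_flops(n, k, m):
    """Hypothetical helper: ~2*n*k*m flops for an (n, k) @ (k, m) product."""
    return 2 * n * k * m

lhs1_flops = 2 * matmul_flops(n, k, m) + n * m  # two products + one add
rhs1_flops = k * m + matmul_flops(n, k, m)      # one add + one product

print(np.allclose(lhs1, rhs1), np.allclose(lhs2, rhs2), lhs1_flops, rhs1_flops)
```

For these shapes the factored form needs roughly half the work, and the saving grows with the inner dimension `k`.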
```python
from pytensor.graph import rewrite_graph
import pytensor.tensor as pt

A, B, C = pt.matrices("ABC")
s = pt.scalar("s")

# Case 1: A@B + A@C, which could be rewritten as A@(B+C)
o1 = A @ B + A @ C
rewrite_graph(o1, include=("fast_run",), exclude=("inplace",)).dprint()

print()

# Case 2: (s*A) @ B, which could be rewritten as s*(A@B)
o2 = (s * A) @ B
rewrite_graph(o2, include=("fast_run",), exclude=("inplace",)).dprint()
```
```
Gemm{no_inplace} [id A]
 ├─ Dot22 [id B]
 │  ├─ A [id C]
 │  └─ B [id D]
 ├─ 1.0 [id E]
 ├─ A [id C]
 ├─ C [id F]
 └─ 1.0 [id E]

Dot22 [id A]
 ├─ Mul [id B]
 │  ├─ ExpandDims{axes=[0, 1]} [id C]
 │  │  └─ s [id D]
 │  └─ A [id E]
 └─ B [id F]
```
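To make the printed graphs concrete, here is a NumPy sketch that mimics them and the graphs one would want instead. It assumes the `Gemm{no_inplace}` inputs are ordered `(z, a, x, y, b)` and that the op computes `b*z + a*(x @ y)`; `gemm_like` is a hypothetical stand-in, not a PyTensor or BLAS function.

```python
import numpy as np

def gemm_like(z, a, x, y, b):
    """Hypothetical NumPy stand-in for a GEMM-style op: returns b*z + a*(x @ y).

    Assumes the same input order as the printed Gemm node: (z, a, x, y, b).
    """
    return b * z + a * (x @ y)

rng = np.random.default_rng(0)
A = rng.standard_normal((4, 5))
B = rng.standard_normal((5, 3))
C = rng.standard_normal((5, 3))
s = 2.5

# Case 1, current graph: Gemm(Dot22(A, B), 1.0, A, C, 1.0) -> two matrix products.
current_1 = gemm_like(A @ B, 1.0, A, C, 1.0)
# Desired graph: a single product of A with (B + C).
desired_1 = A @ (B + C)

# Case 2, current graph: Dot22(Mul(s, A), B) -> scale A elementwise, then multiply.
current_2 = (s * A) @ B
# Desired graph: one GEMM-style call with the scalar as `a` and b = 0,
# so the z input does not contribute to the result.
desired_2 = gemm_like(np.zeros((4, 3)), s, A, B, 0.0)

print(np.allclose(current_1, desired_1), np.allclose(current_2, desired_2))
```

In the second case the elementwise `Mul` over `A` disappears: the scalar rides along as the GEMM `alpha` argument instead of costing a separate pass over the matrix.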
