|
59 | 59 | "MultivariateNormalProposal", |
60 | 60 | ] |
61 | 61 |
|
62 | | -from pymc.util import get_value_vars_from_user_vars |
| 62 | +from pymc.util import RandomGenerator, get_value_vars_from_user_vars |
63 | 63 |
|
64 | 64 | # Available proposal distributions for Metropolis |
65 | 65 |
|
@@ -302,7 +302,7 @@ def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]: |
302 | 302 | accept_rate = self.delta_logp(q, q0d) |
303 | 303 | q, accepted = metrop_select(accept_rate, q, q0d, rng=self.rng) |
304 | 304 | self.accept_rate_iter = accept_rate |
305 | | - self.accepted_iter = accepted |
| 305 | + self.accepted_iter[0] = accepted |
306 | 306 | self.accepted_sum += accepted |
307 | 307 |
|
308 | 308 | self.steps_until_tune -= 1 |
@@ -622,14 +622,16 @@ class CategoricalGibbsMetropolis(ArrayStep): |
622 | 622 |
|
623 | 623 | _state_class = CategoricalGibbsMetropolisState |
624 | 624 |
|
625 | | - def __init__(self, vars, proposal="uniform", order="random", model=None, rng=None): |
| 625 | + def __init__( |
| 626 | + self, vars, proposal="uniform", order="random", model=None, rng: RandomGenerator = None |
| 627 | + ): |
626 | 628 | model = pm.modelcontext(model) |
627 | 629 |
|
628 | 630 | vars = get_value_vars_from_user_vars(vars, model) |
629 | 631 |
|
630 | 632 | initial_point = model.initial_point() |
631 | 633 |
|
632 | | - dimcats = [] |
| 634 | + dimcats: list[tuple[int, int]] = [] |
633 | 635 | # The above variable is a list of pairs (aggregate dimension, number |
634 | 636 | # of categories). For example, if vars = [x, y] with x being a 2-D |
635 | 637 | # variable with M categories and y being a 3-D variable with N |
@@ -665,10 +667,10 @@ def __init__(self, vars, proposal="uniform", order="random", model=None, rng=Non |
665 | 667 | self.dimcats = [dimcats[j] for j in order] |
666 | 668 |
|
667 | 669 | if proposal == "uniform": |
668 | | - self.astep = self.astep_unif |
| 670 | + self.astep = self.astep_unif # type: ignore[assignment] |
669 | 671 | elif proposal == "proportional": |
670 | 672 | # Use the optimized "Metropolized Gibbs Sampler" described in Liu96. |
671 | | - self.astep = self.astep_prop |
| 673 | + self.astep = self.astep_prop # type: ignore[assignment] |
672 | 674 | else: |
673 | 675 | raise ValueError("Argument 'proposal' should either be 'uniform' or 'proportional'") |
674 | 676 |
|
|
0 commit comments