Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 74 additions & 16 deletions bigframes/ml/metrics/_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@
"""Metrics functions for evaluating models. This module is styled after
scikit-learn's metrics module: https://scikit-learn.org/stable/modules/metrics.html."""

from __future__ import annotations

import inspect
import typing
from typing import Tuple, Union
from typing import Literal, overload, Tuple, Union

import bigframes_vendored.constants as constants
import bigframes_vendored.sklearn.metrics._classification as vendored_metrics_classification
Expand Down Expand Up @@ -259,31 +261,64 @@ def recall_score(
recall_score.__doc__ = inspect.getdoc(vendored_metrics_classification.recall_score)


@overload
def precision_score(
    y_true: bpd.DataFrame | bpd.Series,
    y_pred: bpd.DataFrame | bpd.Series,
    *,
    pos_label: int | float | bool | str = ...,
    average: Literal["binary"] = ...,
) -> float:
    # average="binary": precision of the single class named by pos_label,
    # returned as a plain float.
    ...


@overload
def precision_score(
    y_true: bpd.DataFrame | bpd.Series,
    y_pred: bpd.DataFrame | bpd.Series,
    *,
    pos_label: int | float | bool | str = ...,
    average: None = ...,
) -> pd.Series:
    # average=None: per-class precision, returned as a pandas Series
    # indexed by label.
    ...


def precision_score(
    y_true: bpd.DataFrame | bpd.Series,
    y_pred: bpd.DataFrame | bpd.Series,
    *,
    pos_label: int | float | bool | str = 1,
    average: Literal["binary"] | None = "binary",
) -> pd.Series | float:
    # Normalize both inputs to BigFrames Series before dispatching on `average`.
    y_true_series, y_pred_series = utils.batch_convert_to_series(y_true, y_pred)

    # average=None: per-class precision as a pandas Series indexed by label.
    if average is None:
        return _precision_score_per_class(y_true_series, y_pred_series)

    # average="binary": precision for the single class named by `pos_label`.
    if average == "binary":
        return _precision_score_binary_pos_only(y_true_series, y_pred_series, pos_label)

    # Other averaging modes ("micro", "macro", "weighted", ...) are not
    # implemented yet; point users at the feedback link.
    raise NotImplementedError(
        f"Unsupported 'average' param value: {average}. {constants.FEEDBACK_LINK}"
    )


# Reuse the vendored scikit-learn docstring so help() matches the sklearn API.
precision_score.__doc__ = inspect.getdoc(
    vendored_metrics_classification.precision_score
)


def _precision_score_per_class(y_true: bpd.Series, y_pred: bpd.Series) -> pd.Series:
is_accurate = y_true == y_pred
unique_labels = (
bpd.concat([y_true_series, y_pred_series], join="outer")
bpd.concat([y_true, y_pred], join="outer")
.drop_duplicates()
.sort_values(inplace=False)
)
index = unique_labels.to_list()

precision = (
is_accurate.groupby(y_pred_series).sum()
/ is_accurate.groupby(y_pred_series).count()
is_accurate.groupby(y_pred).sum() / is_accurate.groupby(y_pred).count()
).to_pandas()

precision_score = pd.Series(0, index=index)
Expand All @@ -293,9 +328,32 @@ def precision_score(
return precision_score


precision_score.__doc__ = inspect.getdoc(
vendored_metrics_classification.precision_score
)
def _precision_score_binary_pos_only(
y_true: bpd.Series, y_pred: bpd.Series, pos_label: int | float | bool | str
) -> float:
if y_true.drop_duplicates().count() != 2 or y_pred.drop_duplicates().count() != 2:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This may create extra queries with y_true.drop_duplicates().to_list() in line 340. We may want to merge them.

Can you take a look at how many queries are created when running this function?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the result: https://screenshot.googleplex.com/9aFGAUSHzuPDPtB. it feels weird because no query jobs are printed out.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Local execution? @TrevorBergeron

raise ValueError(
"Target is multiclass but average='binary'. Please choose another average setting."
)

total_labels = set(
y_true.drop_duplicates().to_list() + y_pred.drop_duplicates().to_list()
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should probably avoid drop_duplicates, it has overhead from trying to preserve ordering, try unique(keep_order=False) instead. Also try to minimize query count

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code updated. This is the execution output: https://screenshot.googleplex.com/9aFGAUSHzuPDPtB.

It's weird that no query job links are provided.


if len(total_labels) != 2:
raise ValueError(
"Target is multiclass but average='binary'. Please choose another average setting."
)

if pos_label not in total_labels:
raise ValueError(
f"pos_labe={pos_label} is not a valid label. It should be one of {list(total_labels)}"
)

target_elem_idx = y_pred == pos_label
is_accurate = y_pred[target_elem_idx] == y_true[target_elem_idx]

return is_accurate.sum() / is_accurate.count()


def f1_score(
Expand Down
51 changes: 51 additions & 0 deletions tests/system/small/ml/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -743,6 +743,57 @@ def test_precision_score_series(session):
)


@pytest.mark.parametrize(
    ("pos_label", "expected_score"),
    [
        ("a", 1 / 3),
        ("b", 0),
    ],
)
def test_precision_score_binary(session, pos_label, expected_score):
    # Two-class frame: "a" is predicted three times but correct only once.
    bf_df = session.read_pandas(
        pd.DataFrame(
            {
                "y_true": ["a", "a", "a", "b", "b"],
                "y_pred": ["b", "b", "a", "a", "a"],
            }
        )
    )

    score = metrics.precision_score(
        bf_df["y_true"], bf_df["y_pred"], average="binary", pos_label=pos_label
    )

    assert score == pytest.approx(expected_score)


@pytest.mark.parametrize(
    ("y_true", "y_pred", "pos_label"),
    [
        pytest.param(
            pd.Series([1, 2, 3]), pd.Series([1, 0]), 1, id="y_true-non-binary-label"
        ),
        pytest.param(
            pd.Series([1, 0]), pd.Series([1, 2, 3]), 1, id="y_pred-non-binary-label"
        ),
        pytest.param(
            pd.Series([1, 0]), pd.Series([1, 2]), 1, id="combined-non-binary-label"
        ),
        pytest.param(pd.Series([1, 0]), pd.Series([1, 0]), 2, id="invalid-pos_label"),
    ],
)
def test_precision_score_binary_invalid_input_raise_error(
    session, y_true, y_pred, pos_label
):
    # Each case violates the binary-average contract and must raise ValueError.
    bf_y_true = session.read_pandas(y_true)
    bf_y_pred = session.read_pandas(y_pred)

    with pytest.raises(ValueError):
        metrics.precision_score(
            bf_y_true, bf_y_pred, average="binary", pos_label=pos_label
        )


def test_f1_score(session):
pd_df = pd.DataFrame(
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def precision_score(
default='binary'
This parameter is required for multiclass/multilabel targets.
Possible values are 'None', 'micro', 'macro', 'samples', 'weighted', 'binary'.
Only average=None is supported.
Only None and 'binary' are supported.

Returns:
precision: float (if average is not None) or Series of float of shape \
Expand Down