File size: 1,078 Bytes
3de7bf6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 |
"""F1 Score metric.
Provides ``F1Score``, a thin wrapper around torchmetrics' ``BinaryF1Score``,
added for convenience.
"""
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import logging
from typing import Any, Literal
from torchmetrics.classification import BinaryF1Score
logger = logging.getLogger(__name__)
class F1Score(BinaryF1Score):
    """Backwards-compatible F1 score built on torchmetrics' ``BinaryF1Score``.

    The F1 score offered by torchmetrics requires a ``task`` parameter. This
    subclass exists so that the current configuration keeps working without
    it: every constructor argument is handed to ``BinaryF1Score`` unchanged.

    Each instantiation logs a warning announcing that the class will be
    removed in v1.1 and that ``BinaryF1Score`` should be used instead.

    Args:
        threshold: Forwarded to ``BinaryF1Score``. Defaults to ``0.5``.
        multidim_average: Forwarded to ``BinaryF1Score``; either ``"global"``
            or ``"samplewise"``. Defaults to ``"global"``.
        ignore_index: Forwarded to ``BinaryF1Score``. Defaults to ``None``.
        validate_args: Forwarded to ``BinaryF1Score``. Defaults to ``True``.
        **kwargs: Any additional keyword arguments, forwarded to
            ``BinaryF1Score`` as-is.

    See the torchmetrics documentation of ``BinaryF1Score`` for the meaning
    of each forwarded argument.
    """

    def __init__(
        self,
        threshold: float = 0.5,
        multidim_average: Literal["global"] | Literal["samplewise"] = "global",
        ignore_index: int | None = None,
        validate_args: bool = True,
        **kwargs: Any,  # noqa: ANN401
    ) -> None:
        # Forward by keyword so each value is explicitly bound to the
        # parent parameter of the same name.
        super().__init__(
            threshold=threshold,
            multidim_average=multidim_average,
            ignore_index=ignore_index,
            validate_args=validate_args,
            **kwargs,
        )

        # Emitted once per instance, after the parent metric is fully set up.
        deprecation_notice = (
            "F1Score class exists for backwards compatibility. It will be removed in v1.1."
            " Please use BinaryF1Score from torchmetrics instead"
        )
        logger.warning(deprecation_notice)
|