File size: 1,078 Bytes
3de7bf6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
"""F1 Score metric.

This is added for convenience.
"""

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0


import logging
from typing import Any, Literal

from torchmetrics.classification import BinaryF1Score

logger = logging.getLogger(__name__)


class F1Score(BinaryF1Score):
    """Deprecated, backwards-compatible wrapper around torchmetrics' ``BinaryF1Score``.

    torchmetrics' generic ``F1Score`` requires a ``task`` argument. This subclass
    keeps the older, task-free constructor working by always building a binary F1
    metric. A deprecation warning is logged every time an instance is created.

    Args:
        threshold: Forwarded to ``BinaryF1Score`` as ``threshold``.
        multidim_average: Forwarded to ``BinaryF1Score``; either ``"global"`` or
            ``"samplewise"``.
        ignore_index: Forwarded to ``BinaryF1Score``; ``None`` disables it.
        validate_args: Forwarded to ``BinaryF1Score``.
        **kwargs: Any additional keyword arguments, passed through unchanged to
            ``BinaryF1Score``.
    """

    def __init__(
        self,
        threshold: float = 0.5,
        multidim_average: Literal["global", "samplewise"] = "global",
        ignore_index: int | None = None,
        validate_args: bool = True,
        **kwargs: Any,  # noqa: ANN401
    ) -> None:
        # Forward by keyword rather than position: if torchmetrics reorders or
        # inserts positional parameters in BinaryF1Score.__init__, positional
        # forwarding would silently bind these values to the wrong parameters.
        super().__init__(
            threshold=threshold,
            multidim_average=multidim_average,
            ignore_index=ignore_index,
            validate_args=validate_args,
            **kwargs,
        )
        logger.warning(
            "F1Score class exists for backwards compatibility. It will be removed in v1.1."
            " Please use BinaryF1Score from torchmetrics instead",
        )