Morgan Funtowicz
commited on
Commit
·
a51f009
1
Parent(s):
7ff080c
feat(parakeet): use bf16 for inference
Browse files- handler.py +2 -0
handler.py
CHANGED
@@ -3,6 +3,7 @@ import zlib
|
|
3 |
from functools import partial
|
4 |
from io import BytesIO
|
5 |
|
|
|
6 |
from hfendpoints.openai import Context, run
|
7 |
from hfendpoints.openai.audio import AutomaticSpeechRecognitionEndpoint, SegmentBuilder, Segment, \
|
8 |
TranscriptionRequest, TranscriptionResponse, TranscriptionResponseKind, VerboseTranscription
|
@@ -40,6 +41,7 @@ class NemoAsrHandler(Handler):
|
|
40 |
def __init__(self, config: EndpointConfig):
|
41 |
logger.info(config.repository)
|
42 |
self._model = ASRModel.from_pretrained(model_name=str(config.repository)).eval()
|
|
|
43 |
|
44 |
async def __call__(self, request: TranscriptionRequest, ctx: Context) -> TranscriptionResponse:
|
45 |
with logger.contextualize(request_id=ctx.request_id):
|
|
|
3 |
from functools import partial
|
4 |
from io import BytesIO
|
5 |
|
6 |
+
import torch
|
7 |
from hfendpoints.openai import Context, run
|
8 |
from hfendpoints.openai.audio import AutomaticSpeechRecognitionEndpoint, SegmentBuilder, Segment, \
|
9 |
TranscriptionRequest, TranscriptionResponse, TranscriptionResponseKind, VerboseTranscription
|
|
|
41 |
def __init__(self, config: EndpointConfig):
|
42 |
logger.info(config.repository)
|
43 |
self._model = ASRModel.from_pretrained(model_name=str(config.repository)).eval()
|
44 |
+
self._model = self._model.to(torch.bfloat16)
|
45 |
|
46 |
async def __call__(self, request: TranscriptionRequest, ctx: Context) -> TranscriptionResponse:
|
47 |
with logger.contextualize(request_id=ctx.request_id):
|