Morgan Funtowicz committed on
Commit
a51f009
·
1 Parent(s): 7ff080c

feat(parakeet): use bf16 for inference

Browse files
Files changed (1) hide show
  1. handler.py +2 -0
handler.py CHANGED
@@ -3,6 +3,7 @@ import zlib
3
  from functools import partial
4
  from io import BytesIO
5
 
 
6
  from hfendpoints.openai import Context, run
7
  from hfendpoints.openai.audio import AutomaticSpeechRecognitionEndpoint, SegmentBuilder, Segment, \
8
  TranscriptionRequest, TranscriptionResponse, TranscriptionResponseKind, VerboseTranscription
@@ -40,6 +41,7 @@ class NemoAsrHandler(Handler):
40
  def __init__(self, config: EndpointConfig):
41
  logger.info(config.repository)
42
  self._model = ASRModel.from_pretrained(model_name=str(config.repository)).eval()
 
43
 
44
  async def __call__(self, request: TranscriptionRequest, ctx: Context) -> TranscriptionResponse:
45
  with logger.contextualize(request_id=ctx.request_id):
 
3
  from functools import partial
4
  from io import BytesIO
5
 
6
+ import torch
7
  from hfendpoints.openai import Context, run
8
  from hfendpoints.openai.audio import AutomaticSpeechRecognitionEndpoint, SegmentBuilder, Segment, \
9
  TranscriptionRequest, TranscriptionResponse, TranscriptionResponseKind, VerboseTranscription
 
41
  def __init__(self, config: EndpointConfig):
42
  logger.info(config.repository)
43
  self._model = ASRModel.from_pretrained(model_name=str(config.repository)).eval()
44
+ self._model = self._model.to(torch.bfloat16)
45
 
46
  async def __call__(self, request: TranscriptionRequest, ctx: Context) -> TranscriptionResponse:
47
  with logger.contextualize(request_id=ctx.request_id):