velocity-ai commited on
Commit
d9c2292
·
verified ·
1 Parent(s): 0ebdffc

Update code/inference.py

Browse files
Files changed (1) hide show
  1. code/inference.py +3 -4
code/inference.py CHANGED
@@ -2,7 +2,7 @@ import os
2
  import json
3
  import torch
4
  import torch.nn as nn
5
- from transformers import AutoModel, AutoTokenizer, AutoConfig
6
  import logging
7
 
8
  logger = logging.getLogger(__name__)
@@ -43,11 +43,10 @@ def model_fn(model_dir, context=None):
43
 
44
  # Load config and specify it's a Phi3Config
45
  config = AutoConfig.from_pretrained(model_id,
46
- num_labels=2,
47
  trust_remote_code=True)
48
 
49
- # Load base model
50
- base_model = AutoModel.from_pretrained(
51
  model_id,
52
  config=config,
53
  torch_dtype=torch.bfloat16 if device.type == 'cuda' else torch.float32,
 
2
  import json
3
  import torch
4
  import torch.nn as nn
5
+ from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
6
  import logging
7
 
8
  logger = logging.getLogger(__name__)
 
43
 
44
  # Load config and specify it's a Phi3Config
45
  config = AutoConfig.from_pretrained(model_id,
 
46
  trust_remote_code=True)
47
 
48
+ # Load base model using AutoModelForCausalLM
49
+ base_model = AutoModelForCausalLM.from_pretrained(
50
  model_id,
51
  config=config,
52
  torch_dtype=torch.bfloat16 if device.type == 'cuda' else torch.float32,