-
Notifications
You must be signed in to change notification settings - Fork 65
/
Copy pathserving.py
61 lines (49 loc) · 2.38 KB
/
serving.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from flask import Flask, request, jsonify
import numpy as np
from transformers import AutoTokenizer
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import logging
from sampling import autoregressive_sampling, speculative_sampling, speculative_sampling_v2
# Flask application object; the /predict route below is registered on it.
app = Flask(__name__)
# NOTE(review): `pipeline` is not referenced anywhere in the visible source —
# possibly a leftover; confirm no other module imports it before removing.
pipeline = None
# Process-wide Server instance. It stays None until the `__main__` block at the
# bottom of this file assigns it; the /predict handler reads it on every request.
GLOBAL_SERVER = None
class Server:
    """Hold a small/large causal-LM pair and answer prompts via speculative sampling.

    Both models are loaded once at construction time and moved to CUDA when
    available, otherwise to the CPU. The tokenizer is loaded from the small
    model's checkpoint and used for both encoding the prompt and decoding the
    output (assumes both models share a vocabulary -- TODO confirm).
    """

    def __init__(self, approx_model_name: str, target_model_name: str, *,
                 num_tokens: int = 40, top_k: int = 10, top_p: float = 0.9) -> None:
        """Load the models and tokenizer and store the sampling settings.

        Args:
            approx_model_name: Name or path given to ``from_pretrained`` for
                the small model; the tokenizer is loaded from it as well.
            target_model_name: Name or path given to ``from_pretrained`` for
                the large model.
            num_tokens: Passed to ``speculative_sampling`` as its fourth
                positional argument (presumably the number of tokens to
                generate -- confirm against the sampling module).
            top_k: Forwarded to ``speculative_sampling`` as ``top_k``.
            top_p: Forwarded to ``speculative_sampling`` as ``top_p``.
        """
        self._device = 'cuda' if torch.cuda.is_available() else 'cpu'
        logging.info("begin load models")
        self._small_model = AutoModelForCausalLM.from_pretrained(
            approx_model_name, trust_remote_code=True).to(self._device)
        self._large_model = AutoModelForCausalLM.from_pretrained(
            target_model_name, trust_remote_code=True).to(self._device)
        # Use the same trust_remote_code setting as the models: checkpoints
        # that need custom model code usually ship a custom tokenizer too.
        self._tokenizer = AutoTokenizer.from_pretrained(
            approx_model_name, trust_remote_code=True)
        logging.info("fininsh load models")
        self.num_tokens = num_tokens
        self.top_k = top_k
        self.top_p = top_p

    def process_request(self, request: dict) -> str:
        """Generate a completion for ``request['prompt']``.

        Args:
            request: Mapping with a ``'prompt'`` entry holding the input text.

        Returns:
            The decoded first sequence returned by ``speculative_sampling``,
            with special tokens skipped.

        Raises:
            KeyError: If ``request`` has no ``'prompt'`` entry.
        """
        input_str = request['prompt']
        logging.info(f"recieve request {input_str}")
        input_ids = self._tokenizer.encode(input_str, return_tensors='pt').to(self._device)
        # Serving never back-propagates; no_grad keeps autograd from recording
        # the forward passes regardless of what the sampler does internally.
        with torch.no_grad():
            output = speculative_sampling(input_ids,
                                          self._small_model,
                                          self._large_model,
                                          self.num_tokens,
                                          top_k=self.top_k,
                                          top_p=self.top_p)
        return self._tokenizer.decode(output[0], skip_special_tokens=True)
# Set up a route to listen for inference requests
@app.route('/predict', methods=['POST'])
def predict():
    """Handle ``POST /predict``: run inference on a JSON body with a ``prompt``.

    Returns:
        On success, the JSON-encoded result of
        ``GLOBAL_SERVER.process_request``. Otherwise a JSON object with an
        ``error`` message and status 415 (body not declared as JSON),
        400 (body is not a JSON object containing ``prompt``) or
        503 (``GLOBAL_SERVER`` has not been assigned).
    """
    # is_json inspects the parsed MIME type, so a header such as
    # "application/json; charset=utf-8" is accepted and a missing
    # Content-Type header is rejected here instead of failing on lookup.
    if not request.is_json:
        return jsonify({'error': 'Invalid content type'}), 415

    # silent=True yields None for malformed JSON so it is reported below.
    request_data = request.get_json(silent=True)
    if not isinstance(request_data, dict) or 'prompt' not in request_data:
        return jsonify(
            {'error': "Request body must be a JSON object with a 'prompt' field"}
        ), 400

    # GLOBAL_SERVER is only assigned by the __main__ block of this file; it is
    # still None when the app object is served some other way.
    if GLOBAL_SERVER is None:
        return jsonify({'error': 'Model server is not initialized'}), 503

    # Perform inference
    result = GLOBAL_SERVER.process_request(request_data)
    # Return the inference results
    return jsonify(result)
if __name__ == '__main__':
    # Script entry point: parse options, load the models, start the HTTP server.
    # argparse is only needed when this file is run directly, so import it here.
    import argparse

    parser = argparse.ArgumentParser(
        description='Serve speculative sampling over HTTP.')
    # Defaults reproduce the previously hard-coded values, so running the
    # script with no arguments behaves as before.
    parser.add_argument('--approx_model_name',
                        default="/share_nfs/fangjiarui/root/code/hf_models/bloom-560m",
                        help='name or path of the small model')
    parser.add_argument('--target_model_name',
                        default="/share_nfs/fangjiarui/root/code/hf_models/bloomz-7b1",
                        help='name or path of the large model')
    parser.add_argument('--host', default='0.0.0.0',
                        help='interface the Flask server binds to')
    parser.add_argument('--port', type=int, default=5000,
                        help='TCP port the Flask server listens on')
    args = parser.parse_args()

    # Without configuration the root logger drops INFO records, so the
    # logging.info calls made while loading and serving would never appear.
    logging.basicConfig(level=logging.INFO)

    # Assigned at module level (not inside a function) so the /predict handler
    # sees the instance through the module global.
    GLOBAL_SERVER = Server(approx_model_name=args.approx_model_name,
                           target_model_name=args.target_model_name)
    # Start the Flask service
    app.run(host=args.host, port=args.port)