
batched inference #108

Closed
wants to merge 0 commits into from

Conversation

@varshith15 (Contributor) commented Jul 31, 2024

  • batched inference unit test case for llama
  • batched inference unit test case for llava
  • take lists as input for prompt, image_str, and request_id
  • fix callback order
  • online batching in the ChatGPT API
  • patch for top_p != 1 in batched sampling
  • early stopping for a sequence in the batch if EOS has already occurred (currently only the last token is checked)
  • online batching conditions updated (timeout, max_batch, etc.); a rough sketch of this batching window follows below
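
For context, a rough sketch of how such an online-batching window could look. This is only illustrative: the asyncio.Queue layout, the constants, and the process_prompt_batch name are assumptions, not the actual exo API.

```python
import asyncio

MAX_BATCH = 4          # illustrative values, not the PR's actual defaults
BATCH_TIMEOUT = 0.05   # seconds to wait for more requests before flushing

async def batching_loop(queue: asyncio.Queue, process_prompt_batch):
  """Collect queued requests until max_batch or timeout, then run one batched call."""
  while True:
    prompts, image_strs, request_ids = [], [], []
    # Block until the first request arrives, then opportunistically collect more.
    prompt, image_str, request_id = await queue.get()
    prompts.append(prompt); image_strs.append(image_str); request_ids.append(request_id)
    deadline = asyncio.get_running_loop().time() + BATCH_TIMEOUT
    while len(prompts) < MAX_BATCH:
      remaining = deadline - asyncio.get_running_loop().time()
      if remaining <= 0:
        break
      try:
        prompt, image_str, request_id = await asyncio.wait_for(queue.get(), remaining)
      except asyncio.TimeoutError:
        break
      prompts.append(prompt); image_strs.append(image_str); request_ids.append(request_id)
    # One batched inference call; results come back in the same order as request_ids.
    await process_prompt_batch(prompts, image_strs, request_ids)
```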

@varshith15 (Contributor, Author) commented Jul 31, 2024

Hey @AlexCheema,
To start with, I'm assuming that in-flight batching isn't needed right away (we'd have to revamp a lot of the API to make in-flight batching happen). We can just batch requests in the API section, run inference, and split the results back apart in the API section. Thoughts?

@AlexCheema (Contributor) commented

> Hey @AlexCheema, to start with I'm assuming that in-flight batching isn't needed right away (we'd have to revamp a lot of the API to make in-flight batching happen). We can just batch requests in the API section, run inference, and split the results back apart in the API section. Thoughts?

Sounds good. Yes, no need for in-flight batching; we can add that in a subsequent PR.
Let's do this!

Related issue: #1

@varshith15 (Contributor, Author) commented Aug 12, 2024

Hey @AlexCheema,
I was working on batching requests at the endpoint, but it seems a little gimmicky to online-batch them ourselves, and there are also issues concerning the stream callback.

My current idea for online batching is to basically combine the requests ("req1_req2_req3..") and then check whether a req_id is present in the broadcast result id and split the result based on the position of the id, but that seems a bit gimmicky.

I think it's better to process requests concurrently using https://github.com/omnilib/aiomultiprocess (#4), sketched below, rather than online batching them.

If the user gives a batch as input we could process it, but I think it's best that we don't online batch it.

Thoughts?
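
For reference, a minimal sketch of the aiomultiprocess alternative mentioned above; handle_request is a hypothetical stand-in for single-request inference, not an existing exo function.

```python
import asyncio
from aiomultiprocess import Pool

async def handle_request(prompt: str) -> str:
  # Placeholder for real single-request inference.
  return f"response to {prompt!r}"

async def main(prompts: list[str]) -> list[str]:
  async with Pool() as pool:
    # Each prompt is handled concurrently in a separate worker process.
    return await pool.map(handle_request, prompts)

if __name__ == "__main__":
  print(asyncio.run(main(["hello", "how are you?"])))
```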

@varshith15 (Contributor, Author) commented

I've thought of a better way: just update the functions to take lists of req_ids, prompts, and img_strs, and broadcast back the specific results based on the index of the request in the batch. That way we don't have to change a lot of code in the ChatGPT API.
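
Roughly, the split-by-index idea looks like this (assumed shapes and names, for illustration only):

```python
# The batched result is broadcast along with the full list of request_ids; each
# waiting API handler picks out its own output by the position of its request_id.
def result_for_request(request_ids: list[str], batched_outputs: list[list[int]], my_request_id: str) -> list[int]:
  idx = request_ids.index(my_request_id)  # position of this request in the batch
  return batched_outputs[idx]             # that request's row of the batched output
```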

@varshith15 (Contributor, Author) commented

Hey @AlexCheema, batched inference with online batching works now. There are a few more small patches required; I'll push them in a bit. Could you please review the idea in the meantime?

@varshith15 (Contributor, Author) commented

@AlexCheema it's done! PRM

@varshith15 marked this pull request as ready for review on August 18, 2024, 11:11
@AlexCheema (Contributor) commented

Thanks @varshith15, I will take a proper look at some point next week.

@AlexCheema (Contributor) commented

I haven't had a proper look yet, but I'm wondering what the behaviour is when a request is being processed and another request comes in?

@varshith15 (Contributor, Author) commented

Currently the requests just get processed one after the other (by processed I mean after the await on process_prompt completes), but maybe a semaphore is needed.

@AlexCheema (Contributor) commented

> Currently the requests just get processed one after the other (by processed I mean after the await on process_prompt completes), but maybe a semaphore is needed.

I think this might cause issues currently since the kv_cache is shared. Can you test this out to see if it works? It doesn't work on main but maybe with your changes it works now.

@varshith15 (Contributor, Author) commented Aug 18, 2024

I've tested out a couple of different scenarios by sending requests to the service, and it works as expected.
I'm doing more testing.

@varshith15 (Contributor, Author) commented Aug 19, 2024

> Currently the requests just get processed one after the other (by processed I mean after the await on process_prompt completes), but maybe a semaphore is needed.

> I think this might cause issues currently since the kv_cache is shared. Can you test this out to see if it works? It doesn't work on main but maybe with your changes it works now.

@AlexCheema I think you are right; there's a bug when there's a request already running and a new request comes in: the older one is not able to return its result. Can you expand on what you mean by the issue being due to the kv_cache being shared?

@varshith15 (Contributor, Author) commented Aug 22, 2024

@AlexCheema could you expand on why concurrent processing doesn't work when the kv_cache is shared?

I haven't debugged it yet; I will do it over the weekend.

@AlexCheema (Contributor) commented

> @AlexCheema could you expand on why concurrent processing doesn't work when the kv_cache is shared?

> I haven't debugged it yet; I will do it over the weekend.

I pushed something for MLX that uses an LRU cache holding a pool of KV caches, which should fix it. Can you merge the latest exo and fix conflicts? Then I will take a look through the PR.
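
A hedged sketch of that idea: an LRU pool of per-request KV caches, with make_kv_cache standing in for however the inference engine actually builds a cache (this is not the actual MLX code).

```python
from collections import OrderedDict

class KVCacheLRU:
  """Keep a small pool of per-request KV caches, evicting the least recently used."""

  def __init__(self, capacity: int, make_kv_cache):
    self.capacity = capacity
    self.make_kv_cache = make_kv_cache   # factory that builds a fresh KV cache
    self.caches: OrderedDict[str, object] = OrderedDict()

  def get(self, request_id: str):
    if request_id in self.caches:
      self.caches.move_to_end(request_id)   # mark as most recently used
    else:
      if len(self.caches) >= self.capacity:
        self.caches.popitem(last=False)     # evict the least recently used cache
      self.caches[request_id] = self.make_kv_cache()
    return self.caches[request_id]
```

Usage would be something like cache = kv_lru.get(request_id) at the start of each decode step, so each request keeps its own cache as long as it stays within the pool's capacity.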

exo/helpers.py Outdated

```python
def on_next(self, callback: Callable[..., None]) -> None:
  self.observers.append(callback)

def set(self, *args: T) -> None:
  self.result = args
  self.result.append(args)
```
@varshith15 (Contributor, Author) commented Aug 24, 2024

@AlexCheema the concurrent execution issue is due to this.

@varshith15 (Contributor, Author) commented Aug 24, 2024

@AlexCheema this is done; it works as expected with online batching and concurrent requests now. Please review.

A few remaining things:

  1. Need to add semaphore functionality in the ChatGPT API matching the LRU cache size (see the sketch after this list).
  2. I've moved all the functions from taking a single request_id to taking a list of request_ids. I've updated the model inference part for MLX; please comment on the other engines.
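
A minimal sketch of item 1, assuming a KV_CACHE_SLOTS constant sized to match the LRU cache capacity; the names here are hypothetical, not the PR's actual identifiers.

```python
import asyncio

KV_CACHE_SLOTS = 4                       # assumed to match the LRU cache capacity
request_semaphore = asyncio.Semaphore(KV_CACHE_SLOTS)

async def handle_chat_request(process_prompt, prompt: str, request_id: str):
  # At most KV_CACHE_SLOTS prompts are in flight at once, so an active request's
  # KV cache is never evicted out from under it.
  async with request_semaphore:
    return await process_prompt(prompt, request_id)
```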
