Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BT-932] [BT-891] Update load_model and load_spec references to use repository util #52

Merged
merged 4 commits into from
Nov 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions examples/.envrc.example
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
use flake

export OPENAI_API_KEY=""
66 changes: 42 additions & 24 deletions lib/excision_web/controllers/classifier_controller.ex
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@ defmodule ExcisionWeb.ClassifierController do
use ExcisionWeb, :controller
use OpenApiSpex.ControllerSpecs

require Logger

alias Excision.Excisions
alias Excision.Excisions.Classifier
alias Excision.Util

action_fallback ExcisionWeb.FallbackController

Expand Down Expand Up @@ -84,8 +87,9 @@ defmodule ExcisionWeb.ClassifierController do
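# TODO: loading the tokenizer and model here is really slow; keep them in memory with a GenServer (or Agent?)
# TODO: read model name from classifier (appears done: base_model_name is read just below — confirm and remove)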
model_name = classifier.base_model_name
repository = Util.build_bumblebee_model_repository(model_name)

{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, model_name})
{:ok, tokenizer} = Bumblebee.load_tokenizer(repository)
checkpoint_path = classifier.checkpoint_path

{:ok, spec} =
Expand All @@ -95,7 +99,7 @@ defmodule ExcisionWeb.ClassifierController do

num_labels = decision_site.choices |> Enum.count()
spec = Bumblebee.configure(spec, num_labels: num_labels)
{:ok, model} = Bumblebee.load_model({:hf, model_name}, spec: spec)
{:ok, model} = Bumblebee.load_model(repository, spec: spec)

params =
[checkpoint_path, "parameters.nx"]
Expand Down Expand Up @@ -192,34 +196,48 @@ defmodule ExcisionWeb.ClassifierController do
resp_body =
case Plug.Conn.get_resp_header(resp, "content-encoding") do
["gzip"] ->
:zlib.gunzip(resp.resp_body)
{:ok, :zlib.gunzip(resp.resp_body)}

["br"] ->
{:ok, data} = :brotli.decode(resp.resp_body)
data
:brotli.decode(resp.resp_body)

_ ->
resp.resp_body
if resp.status >= 400 do
Logger.error(
"Got failure response while proxying request for classifier #{classifier.name}: #{resp.status} #{resp.resp_body}"
)

{:error, resp.resp_body}
else
{:ok, resp.resp_body}
end
end
|> Jason.decode!()

# record the decision
Excision.Excisions.create_decision(%{
decision_site_id: decision_site.id,
classifier_id: classifier.id,
input: Jason.encode!(req_body["messages"]),
prediction_id:
decision_site.choices
|> Enum.find(fn c ->
c.name ==
resp_body["choices"]
|> hd()
|> then(& &1["message"]["content"])
|> then(&Jason.decode!/1)
|> then(& &1["value"])
end)
|> then(& &1.id)
})
deserialized_body = resp_body |> elem(1) |> Jason.decode!()

case resp_body do
{:ok, _} ->
# record the decision
Excision.Excisions.create_decision(%{
decision_site_id: decision_site.id,
classifier_id: classifier.id,
input: Jason.encode!(req_body["messages"]),
prediction_id:
decision_site.choices
|> Enum.find(fn c ->
c.name ==
deserialized_body["choices"]
|> hd()
|> then(& &1["message"]["content"])
|> then(&Jason.decode!/1)
|> then(& &1["value"])
end)
|> then(& &1.id)
})

{:error, _} ->
nil
end

resp
end
Expand Down
Loading