Skip to content

Commit

Permalink
Interface update
Browse files Browse the repository at this point in the history
  • Loading branch information
Denis2054 authored Dec 30, 2023
1 parent 1e28fa9 commit 22dd17d
Showing 1 changed file with 25 additions and 20 deletions.
45 changes: 25 additions & 20 deletions Chapter03/BERT_Fine_Tuning_Sentence_Classification_GPU.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
"\n",
"This notebook shows how to fine-tune a transformer model. Many pretrained and fine-tuned transformer models are available online. Some models, such as OpenAI LLMs, do not need to be fine-tuned for many tasks.\n",
"\n",
"September 02, 2023: A Gradio user interface was added at the end of the notebook to interact with the model.\n",
"December,30 2023: A user interface was added at the end of the notebook to interact with the model.\n",
"\n",
"It is highly recommended to understand this notebook to grasp the architecture of transformer models.\n",
"\n",
Expand Down Expand Up @@ -4767,7 +4767,7 @@
"id": "wyArQWINKyHc"
},
"source": [
"# Creating an interface for the trained model with Gradio\n",
"# Creating an interface for the trained model\n",
"\n"
]
},
Expand Down Expand Up @@ -4944,15 +4944,6 @@
"Click on `Clear` to enter a new sentence. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!pip install --upgrade typing_extensions"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -4961,7 +4952,7 @@
},
"outputs": [],
"source": [
"!pip install gradio"
"!pip install ipywidgets"
]
},
{
Expand Down Expand Up @@ -5006,23 +4997,37 @@
}
],
"source": [
"import gradio as gr\n",
"import ipywidgets as widgets\n",
"from IPython.display import display\n",
"\n",
"def gradio_predict_interface(sentence):\n",
"def model_predict_interface(sentence):\n",
" prediction = predict(sentence, model, tokenizer)\n",
" if prediction == 0:\n",
" return \"Grammatically Incorrect\"\n",
" elif prediction == 1:\n",
" return \"Grammatically Correct\"\n",
" else:\n",
" return f\"Label: {prediction}\" # Generic case in case there are more labels or some other representation\n",
" return f\"Label: {prediction}\"\n",
"\n",
"text_input = widgets.Textarea(\n",
" placeholder='Type something',\n",
" description='Sentence:',\n",
" disabled=False,\n",
" layout=widgets.Layout(width='100%', height='50px') # Adjust width and height here\n",
")\n",
"\n",
"output_label = widgets.Label(\n",
" value='',\n",
" layout=widgets.Layout(width='100%', height='25px'), # Adjust width and height here\n",
" style={'description_width': 'initial'}\n",
")\n",
"\n",
"def on_text_submit(change):\n",
" output_label.value = model_predict_interface(change.new)\n",
"\n",
"interface = gr.Interface(fn=gradio_predict_interface,\n",
" inputs=\"text\",\n",
" outputs=\"label\",\n",
" live=True)\n",
"text_input.observe(on_text_submit, names='value')\n",
"\n",
"interface.launch()\n"
"display(text_input, output_label)\n"
]
}
],
Expand Down

0 comments on commit 22dd17d

Please sign in to comment.