Commit

Feature/improve 0729 (#94)
* add sentence-transformers; update hyperparameters

* update default hyperparameters

* correct layer truncation
SeanLee97 authored Jul 30, 2024
1 parent 6f4592d commit 899337f
Showing 4 changed files with 33 additions and 9 deletions.
20 changes: 16 additions & 4 deletions README.md
@@ -218,12 +218,14 @@ from angle_emb.utils import cosine_similarity


angle = AnglE.from_pretrained('mixedbread-ai/mxbai-embed-2d-large-v1', pooling_strategy='cls').cuda()
-# specify layer_index and embedding size to truncate embeddings
+# truncate layer
+angle = angle.truncate_layer(layer_index=22)
+# specify embedding size to truncate embeddings
doc_vecs = angle.encode([
    'The weather is great!',
    'The weather is very good!',
    'i am going to bed'
-], layer_index=22, embedding_size=768)
+], embedding_size=768)

for i, dv1 in enumerate(doc_vecs):
    for dv2 in doc_vecs[i+1:]:
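
The `embedding_size=768` argument in the hunk above shortens each embedding before similarities are computed. The following is a minimal sketch of that idea in plain Python, assuming truncation simply keeps the leading dimensions. The helper names `truncate` and `cosine` and the toy vectors are hypothetical and are not part of `angle_emb`.

```python
import math


def truncate(vec, embedding_size):
    """Keep only the first `embedding_size` dimensions (assumed behaviour)."""
    return vec[:embedding_size]


def cosine(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


# Toy 4-dimensional "embeddings": made-up values, not model output.
doc_vecs = [
    [1.0, 0.0, 2.0, 0.0],
    [1.0, 0.0, 0.0, 2.0],
    [0.0, 1.0, 0.0, 0.0],
]

# Truncate every vector to 2 dimensions, then compare all pairs,
# mirroring the loop structure of the README example.
truncated = [truncate(v, 2) for v in doc_vecs]
for i, dv1 in enumerate(truncated):
    for dv2 in truncated[i + 1:]:
        print(cosine(dv1, dv2))
```

With these toy values, the first two vectors become identical after truncation even though their full-length cosine similarity is only 0.2, so similarity scores can shift when the embedding size is reduced.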
@@ -247,6 +249,9 @@ print(vec)

## 🕸️ Custom Train

+💡 For more details, please refer to the [training and fine-tuning](https://angle.readthedocs.io/en/latest/notes/training.html) guide.
+
+
### 🗂️ 1. Data Preparation

We currently support three dataset formats:
@@ -259,7 +264,7 @@ We currently support three dataset formats:

Prepare your data as a Hugging Face `datasets.Dataset` in whichever of these formats matches your supervised data.

-### 🚂 2. Train with CLI
+### 🚂 2. Train with CLI [Recommended]

Use `angle-trainer` to train your AnglE model from the command line.

@@ -316,7 +321,7 @@ angle.fit(
gradient_accumulation_steps=1,
loss_kwargs={
'cosine_w': 1.0,
-'ibn_w': 1.0,
+'ibn_w': 20.0,
'angle_w': 1.0,
'cosine_tau': 20,
'ibn_tau': 20,
@@ -350,6 +355,13 @@ print('Spearman\'s corrcoef:', corrcoef)
4️⃣ To alleviate information forgetting during fine-tuning, it is better to specify `teacher_name_or_path`. If `teacher_name_or_path` equals `model_name_or_path`, self-distillation is performed. **Note that** the model at `teacher_name_or_path` must use the same tokenizer as `model_name_or_path`; otherwise it will lead to unexpected results.


+## 5. Fine-tuning and Inference for AnglE with `sentence-transformers`
+
+- **Training:** SentenceTransformers also provides an implementation of [AnglE loss](https://sbert.net/docs/package_reference/sentence_transformer/losses.html#angleloss). **However, it is only partially implemented and may not work as well as the official code. We recommend using the official `angle_emb` to fine-tune AnglE models.**
+
+- **Inference:** If your model was trained with `angle_emb` and you want to use it with `sentence-transformers`, you can convert it to a `sentence-transformers` model using the script `examples/convert_to_sentence_transformers.py`.
+
+
# 🫡 Citation

You are welcome to use our code and pre-trained models. If you use our code and pre-trained models, please support us by citing our work as follows:
4 changes: 2 additions & 2 deletions angle_emb/angle.py
@@ -922,8 +922,8 @@ class AngleLoss:
:param dataset_format: Optional[str]. Default None.
"""
def __init__(self,
-            cosine_w: float = 1.0,
-            ibn_w: float = 1.0,
+            cosine_w: float = 0.0,
+            ibn_w: float = 20.0,
             angle_w: float = 1.0,
             cosine_tau: float = 20.0,
             ibn_tau: float = 20.0,
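
The new defaults set `cosine_w` to 0.0 and raise `ibn_w` to 20.0. Below is a minimal sketch of what that could mean for the total loss, assuming the three objectives are combined as a weighted sum and that a zero weight disables its term; the diff does not show this part of `AngleLoss`, so treat it as an illustration only. `LossWeights` and `combine` are hypothetical names, and the per-term loss values are made up.

```python
from dataclasses import dataclass


@dataclass
class LossWeights:
    # Defaults mirror the new values introduced in this hunk.
    cosine_w: float = 0.0
    ibn_w: float = 20.0
    angle_w: float = 1.0


def combine(weights, cosine_loss, ibn_loss, angle_loss):
    """Weighted sum of the three terms (assumed combination rule).

    A term whose weight is not positive is skipped entirely.
    """
    total = 0.0
    if weights.cosine_w > 0:
        total += weights.cosine_w * cosine_loss
    if weights.ibn_w > 0:
        total += weights.ibn_w * ibn_loss
    if weights.angle_w > 0:
        total += weights.angle_w * angle_loss
    return total


# Made-up per-term loss values, chosen to be exact in floating point.
old = combine(LossWeights(cosine_w=1.0, ibn_w=1.0), 0.5, 0.25, 0.125)
new = combine(LossWeights(), 0.5, 0.25, 0.125)
print(old, new)
```

Under this assumption, the old defaults weight all three terms equally, while the new defaults drop the cosine term and let the in-batch-negative term dominate.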
6 changes: 4 additions & 2 deletions docs/notes/quickstart.rst
@@ -142,12 +142,14 @@ Specify `layer_index` and `embedding_size` to truncate embeddings.
angle = AnglE.from_pretrained('mixedbread-ai/mxbai-embed-2d-large-v1', pooling_strategy='cls').cuda()
-# specify layer_index and embedding_size to truncate embeddings
+# truncate layer
+angle = angle.truncate_layer(layer_index=22)
+# specify embedding size to truncate embeddings
doc_vecs = angle.encode([
    'The weather is great!',
    'The weather is very good!',
    'i am going to bed'
-], layer_index=22, embedding_size=768)
+], embedding_size=768)
for i, dv1 in enumerate(doc_vecs):
    for dv2 in doc_vecs[i+1:]:
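
This hunk moves layer selection out of `encode` and into a separate `truncate_layer(layer_index=22)` call made once on the model. The sketch below illustrates the general idea with a toy stack of layers, assuming that truncation keeps the first `layer_index` layers; the diff does not show how `truncate_layer` is implemented, and `truncate_layers`, `run`, and the 24 toy layers are hypothetical.

```python
def truncate_layers(layers, layer_index):
    """Keep only the first `layer_index` layers (assumed behaviour)."""
    return layers[:layer_index]


def run(layers, x):
    """Apply each layer in order, feeding each output into the next layer."""
    for layer in layers:
        x = layer(x)
    return x


# 24 toy "layers": layer k simply adds k to its input.
layers = [lambda x, k=k: x + k for k in range(1, 25)]

# Truncate once up front, then reuse the shorter stack for every call,
# analogous to calling truncate_layer before encode.
truncated = truncate_layers(layers, 22)
print(len(truncated), run(truncated, 0), run(layers, 0))
```

Truncating once up front means later calls no longer need a per-call `layer_index` argument, which matches the updated `encode` call in the hunk.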
12 changes: 11 additions & 1 deletion docs/notes/training.rst
@@ -140,7 +140,7 @@ You can also train a sentence embedding model using the `angle_emb` library. Her
gradient_accumulation_steps=1,
loss_kwargs={
'cosine_w': 1.0,
-'ibn_w': 1.0,
+'ibn_w': 20.0,
'angle_w': 1.0,
'cosine_tau': 20,
'ibn_tau': 20,
@@ -172,6 +172,16 @@ You can also train a sentence embedding model using the `angle_emb` library. Her
4. To alleviate information forgetting during fine-tuning, it is better to specify `teacher_name_or_path`. If `teacher_name_or_path` equals `model_name_or_path`, self-distillation is performed. **Note that** the model at `teacher_name_or_path` must use the same tokenizer as `model_name_or_path`; otherwise it will lead to unexpected results.


+💡 Fine-tuning and Inference with `sentence-transformers`
+---------------------------------------------------------------------------
+
+
+1. **Training:** SentenceTransformers also provides an implementation of `AnglE loss <https://sbert.net/docs/package_reference/sentence_transformer/losses.html#angleloss>`_.
+**However, it is only partially implemented and may not work as well as the official code. We recommend using the official `angle_emb` to fine-tune AnglE models.**
+
+2. **Inference:** If your model was trained with `angle_emb` and you want to use it with `sentence-transformers`, you can convert it to a `sentence-transformers` model using the script `examples/convert_to_sentence_transformers.py <https://github.com/SeanLee97/AnglE/blob/main/scripts/convert_to_sentence_transformer.py>`_.
+
+

💡 Others
-------------------------