Skip to content

Commit

Permalink
Merge pull request #34 from mc-cat-tty/improvements
Browse files Browse the repository at this point in the history
Improvements
  • Loading branch information
mc-cat-tty authored Feb 8, 2024
2 parents 0d184fa + 41b3a2c commit cf69e6a
Show file tree
Hide file tree
Showing 6 changed files with 134 additions and 104 deletions.
4 changes: 2 additions & 2 deletions placerank/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from placerank.query_expansion import *
from placerank.ir_model import *
from placerank.models import *
from placerank.sentiment import BaseSentimentWeightingModel
from placerank.dataset import ReviewsDatabase
from placerank.config import INDEX_DIR, HELP_FILENAME, DATASET_CACHE_FILE, HF_CACHE, REVIEWS_DB, REVIEWS_INDEX
from whoosh.index import open_dir
Expand All @@ -22,8 +23,7 @@ def main() -> None:
window = Window(readme.read())

idx = open_dir(INDEX_DIR)
model = IRModel(WhooshSpellCorrection, ThesaurusQueryExpansion(HF_CACHE), idx)
model = SentimentAwareIRModel(WhooshSpellCorrection, ThesaurusQueryExpansion(HF_CACHE), idx, SentimentRanker(REVIEWS_INDEX), TF_IDF)
model = UnionIRModel(WhooshSpellCorrection, ThesaurusQueryExpansion(HF_CACHE), idx, BaseSentimentWeightingModel(REVIEWS_INDEX))
presenter = Presenter(model, DATASET_CACHE_FILE, ReviewsDatabase(REVIEWS_DB))
loop = MainLoop(window, palette=PALETTE)
loop.run()
Expand Down
54 changes: 15 additions & 39 deletions placerank/ir_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,16 @@

from placerank.views import ResultView, QueryView, ReviewsIndex
from placerank.query_expansion import QueryExpansionService
from placerank.sentiment import BaseSentimentWeightingModel
from placerank import config

class IRModel(ABC):
def __init__(
self,
spell_corrector: Type[SpellCorrectionService],
query_expander: QueryExpansionService,
index: Index,
weighting_model: WeightingModel = BM25F,
weighting_model: WeightingModel = BM25F(),
connector: str = 'AND'
):
self.spell_corrector = spell_corrector(self)
Expand All @@ -51,20 +53,23 @@ def __init__(
self.connector = connector

def get_query_parser(self, query: QueryView) -> qparser.QueryParser:
return qparser.MultifieldParser([i.name.lower() for i in query.search_fields], self.index.schema)
return qparser.MultifieldParser([f.name.lower() for f in query.search_fields], self.index.schema)

def set_autoexpansion(self, autoexpansion: bool):
self._autoexpansion = autoexpansion

def search(self, query: QueryView, **kwargs) -> Tuple(List[ResultView], int):
if isinstance(self.weighting_model, BaseSentimentWeightingModel):
self.weighting_model.set_user_sentiment(query.sentiment_tags)

expanded_query = self.query_expander.expand(query.textual_query, connector = self.connector)

parser = qparser.QueryParser('room_type', self.index.schema)
parser.add_plugin(qparser.OperatorsPlugin())
room_type = parser.parse(query.room_type) if query.room_type else None

parser = self.get_query_parser(query)
query = parser.parse(expanded_query if self._autoexpansion else query.textual_query)

with self.index.searcher(weighting = self.weighting_model) as s:
hits = s.search(query, filter = room_type, **kwargs)
tot = len(hits)
Expand Down Expand Up @@ -95,41 +100,12 @@ def correct(self, query: QueryView) -> str:

return corrected_query.string

class SentimentRanker:
    """Re-ranks search results by the cosine similarity between each
    listing's review-sentiment vector and a user-supplied sentiment vector,
    scaled by the hit's textual score."""

    def __init__(self, reviews_index_path: str):
        # Index mapping listing ids to their aggregated sentiment dicts.
        # (Fix: removed a stray no-op `self` expression that followed this line.)
        self.__reviews_index = ReviewsIndex(reviews_index_path)

    @staticmethod
    def __cosine_similarity(doc: dict, query: dict) -> float:
        """Cosine similarity between two sparse vectors stored as dicts.

        Returns 0 when either vector has zero norm (empty or all-zero).
        """
        d_norm = math.sqrt(sum(v**2 for v in doc.values()))
        q_norm = math.sqrt(sum(v**2 for v in query.values()))

        # Dot product over the shared keys only; missing keys contribute 0.
        num = sum(doc[k]*query[k] for k in (doc.keys() & query.keys()))
        denom = (d_norm * q_norm)

        return num / denom if denom else 0

    def __score(self, doc, sentiment: dict) -> float:
        # Blend sentiment similarity with the hit's textual relevance score.
        return SentimentRanker.__cosine_similarity(self.__get_sentiment_for(doc), sentiment) * doc.score

    def __get_sentiment_for(self, doc) -> dict:
        # Sentiment vector for the listing behind this hit, keyed by its id.
        return self.__reviews_index.get_sentiment_for(int(doc.id))

    def rank(self, docs: List[ResultView], sentiment: dict) -> List[ResultView]:
        """Return ``docs`` sorted by combined sentiment/textual score, best first.

        Fix: ``sentiment`` was annotated ``str`` but is consumed as an
        {emotion: weight} dict by the cosine similarity; annotation corrected.
        """
        sim_docs = map(lambda d: (d, self.__score(d, sentiment)), docs)
        return list(map(itemgetter(0), sorted(sim_docs, key=itemgetter(1), reverse=True)))


if __name__ == "__main__":
results = [ResultView(470330, 0, 0, 0.2), ResultView(267652, 0, 0, 0.9), ResultView(321014, 0, 0, 0.11)]
sentiment = {'optimism': 1, 'approval': 1}
# if __name__ == "__main__":
# results = [ResultView(470330, 0, 0, 0.2), ResultView(267652, 0, 0, 0.9), ResultView(321014, 0, 0, 0.11)]
# sentiment = {'optimism': 1, 'approval': 1}

a = SentimentRanker()
ranked = a.rank(results, sentiment)
for r in ranked:
print(r.id)
# a = SentimentRanker(config.REVIEWS_INDEX)
# ranked = a.rank(results, sentiment)
# for r in ranked:
# print(r.id)
68 changes: 23 additions & 45 deletions placerank/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,60 +2,38 @@
Cooked models ready to ship
"""

import pydash
from whoosh.scoring import WeightingModel, BM25F
from whoosh.index import Index, open_dir
from whoosh.query import *
from whoosh.qparser import QueryParser
from whoosh.qparser import QueryParser, syntax
from whoosh.qparser.plugins import MultifieldPlugin
from typing import Type, Tuple, List
from placerank import query_expansion

from placerank.views import ResultView, QueryView, ReviewsIndex, SearchFields
from placerank.ir_model import IRModel, NoSpellCorrection, SentimentRanker, SpellCorrectionService, WhooshSpellCorrection
from placerank.query_expansion import NoQueryExpansion, QueryExpansionService, ThesaurusQueryExpansion
from placerank.views import *
from placerank.ir_model import *
from placerank.query_expansion import *
from placerank.sentiment import *
from placerank.config import HF_CACHE, INDEX_DIR, REVIEWS_INDEX

class UnionIRModel(IRModel):
    def search(self, query: QueryView, **kwargs):
        """Run the query in OR (union) mode: any term may match independently.

        The textual query "a b c" is rewritten as "a OR b OR c" before
        delegating to IRModel.search.
        """
        # Also make the base model's query expander join terms with OR.
        self.connector = 'OR'
        # Fix: a plain ' OR '.join replaces the pydash intersperse/join
        # chain, which produced doubled spaces in the rewritten query and
        # pulled in a third-party dependency for a one-liner.
        union_query = ' OR '.join(query.textual_query.split())
        # Rebuild the view with the rewritten text; other fields unchanged.
        query = QueryView(
            union_query,
            query.search_fields,
            query.room_type,
            query.sentiment_tags
        )
        return super().search(query, **kwargs)

class SentimentAwareIRModel(UnionIRModel):
    # Union (OR) retrieval model extended with a sentiment re-ranking stage.
    def __init__(
        self,
        spell_corrector: Type[SpellCorrectionService],
        query_expander: QueryExpansionService,
        index: Index,
        sentiment_ranker: SentimentRanker,
        weighting_model: WeightingModel = BM25F
    ):
        # Delegate standard IR setup (spell correction, expansion, index,
        # weighting) to UnionIRModel/IRModel.
        super().__init__(spell_corrector, query_expander, index, weighting_model)
        # Ranker used to reorder retrieved hits by review sentiment —
        # presumably consumed by this model's search() override; confirm.
        self.sentiment_ranker = sentiment_ranker
class MultifieldUnionPlugin(MultifieldPlugin):
    """Multifield plugin whose per-field query copies are OR-ed together."""

    def do_multifield(self, parser, group):
        # Build the default multifield AST, then regroup its nodes under a
        # single OR so a match in any field satisfies the query.
        multifield_ast = super().do_multifield(parser, group)
        return syntax.OrGroup(list(multifield_ast))


class UnionIRModel(IRModel):
    def get_query_parser(self, query: QueryView) -> QueryParser:
        """Build a parser that ORs the query across all requested fields."""
        field_names = [field.name.lower() for field in query.search_fields]
        parser = QueryParser(None, self.index.schema)
        parser.add_plugin(MultifieldUnionPlugin(field_names))
        return parser

    def search(self, query: QueryView, **kwargs) -> Tuple[List[ResultView], int]:
        """Retrieve, then re-rank by sentiment; returns (ranked hits, total count)."""
        # Uniform-weight sentiment vector built from space-separated tags.
        # NOTE(review): split(" ") yields an empty-string key when the tags
        # are blank or doubly spaced — confirm upstream guarantees non-blank
        # sentiment_tags before this runs.
        sentiment = {k: 1 for k in query.sentiment_tags.split(" ")}
        # Retrieve WITHOUT a limit so the re-ranker sees the full result
        # set; the caller's limit is applied after sorting, below.
        limit = kwargs.get('limit', None)
        kwargs['limit'] = None
        docs, dlen = super().search(query, **kwargs)
        sent_ranked_docs = self.sentiment_ranker.rank(docs, sentiment)[:limit]

        return (sent_ranked_docs, dlen)

def main():
idx = open_dir(INDEX_DIR)
sentiment_model = SentimentAwareIRModel(NoSpellCorrection, NoQueryExpansion(), idx, SentimentRanker(REVIEWS_INDEX))
sentiment_model = UnionIRModel(NoSpellCorrection, NoQueryExpansion(), idx, AdvancedSentimentWeightingModel(REVIEWS_INDEX))
sentiment_res = sentiment_model.search(
QueryView(
textual_query = u'apartment in manhattan', # Stopwords like 'in' are removed
Expand All @@ -64,9 +42,9 @@ def main():
)
)[1]

qe_model = IRModel(NoSpellCorrection, ThesaurusQueryExpansion(HF_CACHE), idx)
qe_model.set_autoexpansion(False)
qe_res = qe_model.search(
union_model = UnionIRModel(NoSpellCorrection, NoQueryExpansion(), idx)
union_model.set_autoexpansion(False)
qe_res = union_model.search(
QueryView(
textual_query = u'cheap stay',
search_fields = SearchFields.DESCRIPTION | SearchFields.NEIGHBORHOOD_OVERVIEW | SearchFields.NAME
Expand Down
34 changes: 18 additions & 16 deletions placerank/query_expansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,14 +108,15 @@ def expand(self, query: str, max_results: int = 2, confidence_threshold: float =

expanded_query = (
pydash.chain(expanded_query)
.map(
lambda sublist:
pydash.chain(sublist)
.intersperse('OR')
.value()
)
.map(lambda s: ['('] + s + [')'])
.intercalate(connector)
# .map(
# lambda sublist:
# pydash.chain(sublist)
# .intersperse('OR')
# .value()
# )
# .map(lambda s: ['('] + s + [')'])
# .intercalate(connector)
.flatten_deep()
.value()
)
expanded_query = ' '.join(expanded_query)
Expand Down Expand Up @@ -198,14 +199,15 @@ def expand(self, query: str, max_results: int = 2, confidence_threshold: float =

expanded_query = (
pydash.chain(expanded_query)
.map(
lambda sublist:
pydash.chain(sublist)
.intersperse('OR')
.value()
)
.map(lambda s: ['('] + s + [')'])
.intercalate(connector)
# .map(
# lambda sublist:
# pydash.chain(sublist)
# .intersperse('OR')
# .value()
# )
# .map(lambda s: ['('] + s + [')'])
# .intercalate(connector)
.flatten_deep()
.value()
)
expanded_query = ' '.join(expanded_query)
Expand Down
74 changes: 72 additions & 2 deletions placerank/sentiment.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
from transformers import BertTokenizer, AutoModelForSequenceClassification, pipeline
from whoosh.scoring import WeightingModel, BM25F
from placerank.views import ReviewsIndex
import math
import re
import pydash

class GoEmotionsClassifier:

Expand All @@ -24,9 +29,74 @@ def classify_texts(self, texts):
return self.goemotions(texts)


# Example usage:
class BaseSentimentWeightingModel(BM25F):
    """BM25F variant that multiplies each document's textual score by the
    cosine similarity between the listing's review sentiment and the
    user's requested sentiment tags."""

    def __init__(self, reviews_index_path: str, *args, **kwargs):
        # Tells whoosh to invoke final() once per scored document.
        self.use_final = True
        # {tag: +1/-1} vector, populated by set_user_sentiment(); None
        # (or empty) disables sentiment re-weighting in final().
        self._user_sentiment = None
        self._reviews_index = ReviewsIndex(reviews_index_path)
        super().__init__(*args, **kwargs)

    def _cosine_similarity(self, doc: dict, query: dict) -> float:
        """Cosine similarity between two sparse dict vectors; 0 when either
        has zero norm."""
        d_norm = math.sqrt(sum(v**2 for v in doc.values()))
        q_norm = math.sqrt(sum(v**2 for v in query.values()))

        # Dot product over shared keys only; missing keys contribute 0.
        num = sum(doc[k]*query[k] for k in (doc.keys() & query.keys()))
        denom = (d_norm * q_norm)

        return num / denom if denom else 0

    def _sentiment_score(self, listing_id, sentiment):
        # Similarity between this listing's stored sentiment and the query's.
        return self._cosine_similarity(self._get_sentiment_for(listing_id), sentiment)

    def _get_sentiment_for(self, listing_id):
        return self._reviews_index.get_sentiment_for(int(listing_id))

    def _combine_scores(self, textual_score, sentiment_score):
        # Base policy: plain product of textual and sentiment relevance.
        return textual_score * sentiment_score

    def set_user_sentiment(self, user_sentiment):
        """Parse a whitespace-separated tag string, honouring "not <tag>"
        negations, into the {tag: +1/-1} vector used at scoring time.

        Fixes: str.split() (instead of split(" ") on a string with an
        appended trailing space) no longer produces empty-string keys, so a
        blank input yields an empty dict and falls back to pure textual
        scoring instead of zeroing every result; the \\b word boundary stops
        words merely ending in "not" (e.g. "cannot") triggering negation.
        """
        # Words immediately following a standalone "not" are negated.
        negated_sentiments = set(re.findall(r'\bnot\s+(\S+)', user_sentiment))

        self._user_sentiment = {
            tag: -1 if tag in negated_sentiments else 1
            for tag in user_sentiment.split()
        }
        # The marker word itself is not a sentiment tag.
        self._user_sentiment.pop('not', None)

    def final(self, searcher, docnum, textual_score):
        textual_score = super().final(searcher, docnum, textual_score)

        # No sentiment requested: keep the plain BM25F score.
        if not self._user_sentiment: return textual_score

        id = searcher.stored_fields(docnum)['id']
        sentiment_score = self._sentiment_score(id, self._user_sentiment)
        return self._combine_scores(textual_score, sentiment_score)


class AdvancedSentimentWeightingModel(BaseSentimentWeightingModel):
    """Sentiment weighting that additionally rewards listings with more
    sentiment evidence (more stored review-sentiment entries)."""

    def combine_scores(self, textual_score, sentiment_score, id):
        # Scale the product by how much sentiment data backs this listing.
        evidence_weight = self._reviews_index.get_sentiment_len_for(id)
        return textual_score * sentiment_score * evidence_weight

    def final(self, searcher, docnum, textual_score):
        # NOTE(review): super().final already applies the base sentiment
        # combination when user sentiment is set, so similarity is factored
        # in twice here — confirm this compounding is intended.
        textual_score = super().final(searcher, docnum, textual_score)

        # Without a user sentiment vector, keep the inherited score as-is.
        if not self._user_sentiment: return textual_score

        id = searcher.stored_fields(docnum)['id']
        sentiment_score = self._sentiment_score(id, self._user_sentiment)
        return self.combine_scores(textual_score, sentiment_score, id)


if __name__ == "__main__":
classifier = GoEmotionsClassifier()
texts = ["its happened before?! love my hometown of beautiful new ken 😂😂"]
texts = ["it's happened before?! love my hometown of beautiful new ken 😂😂"]
results = classifier.classify_texts(texts)
print(results)
4 changes: 4 additions & 0 deletions placerank/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,10 @@ def __init__(self, path = "reviews.pickle"):

    def __todate(self, s: str) -> datetime:
        # Parse an ISO-style "YYYY-MM-DD" string into a datetime object.
        return datetime.strptime(s, "%Y-%m-%d")

def get_sentiment_len_for(self, key):
tmp = self.index.get(int(key), {})
return len(tmp)

def get_sentiment_for(self, key, tau_div = 90):
"""
Expand Down

0 comments on commit cf69e6a

Please sign in to comment.