
Commit

Updated to support CDI creation for RAG support and streaming, and added examples to inject ChatMemory for simple RAG.
TheEliteGentleman committed Feb 6, 2025
1 parent d688d9a commit e7c259f
Showing 18 changed files with 120 additions and 118 deletions.
10 changes: 9 additions & 1 deletion examples/glassfish-car-booking/config/llm-config.properties
@@ -26,7 +26,15 @@ smallrye.llm.plugin.docRagRetriever.config.embeddingModel=lookup:default
smallrye.llm.plugin.docRagRetriever.config.maxResults=3
smallrye.llm.plugin.docRagRetriever.config.minScore=0.6


# Chat Memory used by ChatAiService class
smallrye.llm.plugin.chat-ai-service-memory.class=dev.langchain4j.memory.chat.MessageWindowChatMemory
smallrye.llm.plugin.chat-ai-service-memory.scope=jakarta.enterprise.context.ApplicationScoped
smallrye.llm.plugin.chat-ai-service-memory.config.maxMessages=10

# Chat Memory used by FraudAiService class
smallrye.llm.plugin.fraud-ai-service-memory.class=dev.langchain4j.memory.chat.MessageWindowChatMemory
smallrye.llm.plugin.fraud-ai-service-memory.scope=jakarta.enterprise.context.ApplicationScoped
smallrye.llm.plugin.fraud-ai-service-memory.config.maxMessages=5

smallrye.llm.embedding.store.in-memory.file=embedding.json
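
A minimal sketch of what the two memory entries above presumably amount to at runtime, assuming the smallrye.llm.plugin mechanism instantiates the configured class and applies the config.* values. MessageWindowChatMemory.withMaxMessages is the LangChain4j API named in the properties; the CDI bean wiring itself is not shown here.

import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.memory.chat.MessageWindowChatMemory;

public class ChatMemoryConfigSketch {

    // Rough equivalent of smallrye.llm.plugin.chat-ai-service-memory.* (maxMessages=10)
    static ChatMemory chatAiServiceMemory() {
        return MessageWindowChatMemory.withMaxMessages(10);
    }

    // Rough equivalent of smallrye.llm.plugin.fraud-ai-service-memory.* (maxMessages=5)
    static ChatMemory fraudAiServiceMemory() {
        return MessageWindowChatMemory.withMaxMessages(5);
    }
}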

@@ -6,7 +6,7 @@
import java.time.temporal.ChronoUnit;

@SuppressWarnings("CdiManagedBeanInconsistencyInspection")
@RegisterAIService(tools = BookingService.class, chatMemoryMaxMessages = 10, chatLanguageModelName = "chat-model")
@RegisterAIService(tools = BookingService.class, chatMemoryName = "chat-ai-service-memory", chatLanguageModelName = "chat-model")
public interface ChatAiService {

@SystemMessage("""
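With chatMemoryName now pointing at the CDI-managed memory bean, the AI service itself is still injected and called as before. A hedged usage sketch follows; the resource class, the /chat path, and the chat(String) method name are illustrative assumptions, not code from this commit.

import jakarta.inject.Inject;
import jakarta.ws.rs.GET;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.QueryParam;

@Path("/chat")
public class ChatResourceSketch {

    @Inject
    ChatAiService chatAiService; // backed by "chat-model" and the "chat-ai-service-memory" bean

    @GET
    public String ask(@QueryParam("question") String question) {
        // Delegates to the AI service; the exact method exposed by ChatAiService
        // is assumed here for illustration.
        return chatAiService.chat(question);
    }
}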
@@ -6,9 +6,7 @@
import io.smallrye.llm.spi.RegisterAIService;

@SuppressWarnings("CdiManagedBeanInconsistencyInspection")
@RegisterAIService(chatMemoryMaxMessages = 5,

chatLanguageModelName = "chat-model")
@RegisterAIService(chatMemoryName = "fraud-ai-service-memory", chatLanguageModelName = "chat-model")
public interface FraudAiService {

@SystemMessage("""
@@ -30,7 +30,15 @@ smallrye.llm.plugin.docRagRetriever.config.embeddingModel=lookup:default
smallrye.llm.plugin.docRagRetriever.config.maxResults=3
smallrye.llm.plugin.docRagRetriever.config.minScore=0.6


# Chat Memory used by ChatAiService class
smallrye.llm.plugin.chat-ai-service-memory.class=dev.langchain4j.memory.chat.MessageWindowChatMemory
smallrye.llm.plugin.chat-ai-service-memory.scope=jakarta.enterprise.context.ApplicationScoped
smallrye.llm.plugin.chat-ai-service-memory.config.maxMessages=10

# Chat Memory used by FraudAiService class
smallrye.llm.plugin.fraud-ai-service-memory.class=dev.langchain4j.memory.chat.MessageWindowChatMemory
smallrye.llm.plugin.fraud-ai-service-memory.scope=jakarta.enterprise.context.ApplicationScoped
smallrye.llm.plugin.fraud-ai-service-memory.config.maxMessages=5

smallrye.llm.embedding.store.in-memory.file=embedding.json

@@ -10,7 +10,7 @@
import io.smallrye.llm.spi.RegisterAIService;

@SuppressWarnings("CdiManagedBeanInconsistencyInspection")
@RegisterAIService(tools = BookingService.class, chatMemoryMaxMessages = 10, chatLanguageModelName = "chat-model")
@RegisterAIService(tools = BookingService.class, chatMemoryName = "chat-ai-service-memory", chatLanguageModelName = "chat-model")
public interface ChatAiService {

@SystemMessage("""
@@ -12,9 +12,7 @@
import io.smallrye.llm.spi.RegisterAIService;

@SuppressWarnings("CdiManagedBeanInconsistencyInspection")
@RegisterAIService(chatMemoryMaxMessages = 5,

chatLanguageModelName = "chat-model")
@RegisterAIService(chatMemoryName = "fraud-ai-service-memory", chatLanguageModelName = "chat-model")
public interface FraudAiService {

@SystemMessage("""
@@ -20,14 +20,21 @@ smallrye.llm.plugin.chat-model.config.topP=0.1
smallrye.llm.plugin.chat-model.config.timeout=PT120S
smallrye.llm.plugin.chat-model.config.max-retries=2


smallrye.llm.plugin.docRagRetriever.class=dev.langchain4j.rag.content.retriever.EmbeddingStoreContentRetriever
smallrye.llm.plugin.docRagRetriever.config.embeddingStore=lookup:default
smallrye.llm.plugin.docRagRetriever.config.embeddingModel=lookup:default
smallrye.llm.plugin.docRagRetriever.config.maxResults=3
smallrye.llm.plugin.docRagRetriever.config.minScore=0.6

# Chat Memory used by ChatAiService class
smallrye.llm.plugin.chat-ai-service-memory.class=dev.langchain4j.memory.chat.MessageWindowChatMemory
smallrye.llm.plugin.chat-ai-service-memory.scope=jakarta.enterprise.context.ApplicationScoped
smallrye.llm.plugin.chat-ai-service-memory.config.maxMessages=10

# Chat Memory used by FraudAiService class
smallrye.llm.plugin.fraud-ai-service-memory.class=dev.langchain4j.memory.chat.MessageWindowChatMemory
smallrye.llm.plugin.fraud-ai-service-memory.scope=jakarta.enterprise.context.ApplicationScoped
smallrye.llm.plugin.fraud-ai-service-memory.config.maxMessages=5

smallrye.llm.embedding.store.in-memory.file=embedding.json
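
The docRagRetriever entries above map roughly onto the LangChain4j builder below. A hedged sketch, assuming that lookup:default resolves to the default CDI beans for the embedding store and embedding model (represented as plain parameters here).

import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.rag.content.retriever.ContentRetriever;
import dev.langchain4j.rag.content.retriever.EmbeddingStoreContentRetriever;
import dev.langchain4j.store.embedding.EmbeddingStore;

public class DocRagRetrieverSketch {

    static ContentRetriever docRagRetriever(EmbeddingStore<TextSegment> embeddingStore,
                                            EmbeddingModel embeddingModel) {
        return EmbeddingStoreContentRetriever.builder()
                .embeddingStore(embeddingStore) // config.embeddingStore=lookup:default
                .embeddingModel(embeddingModel) // config.embeddingModel=lookup:default
                .maxResults(3)                  // config.maxResults=3
                .minScore(0.6)                  // config.minScore=0.6
                .build();
    }
}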

@@ -10,7 +10,7 @@
import io.smallrye.llm.spi.RegisterAIService;

@SuppressWarnings("CdiManagedBeanInconsistencyInspection")
@RegisterAIService(tools = BookingService.class, chatMemoryMaxMessages = 10, chatLanguageModelName = "chat-model")
@RegisterAIService(tools = BookingService.class, chatMemoryName = "chat-ai-service-memory", chatLanguageModelName = "chat-model")
public interface ChatAiService {

@SystemMessage("""
@@ -12,9 +12,7 @@
import io.smallrye.llm.spi.RegisterAIService;

@SuppressWarnings("CdiManagedBeanInconsistencyInspection")
@RegisterAIService(chatMemoryMaxMessages = 5,

chatLanguageModelName = "chat-model")
@RegisterAIService(chatMemoryName = "fraud-ai-service-memory", chatLanguageModelName = "chat-model")
public interface FraudAiService {

@SystemMessage("""
@@ -12,7 +12,7 @@
import io.smallrye.llm.spi.RegisterAIService;

//@SuppressWarnings("CdiManagedBeanInconsistencyInspection")
@RegisterAIService(scope = ApplicationScoped.class, tools = BookingService.class, chatMemoryMaxMessages = 10)
@RegisterAIService(scope = ApplicationScoped.class, tools = BookingService.class, chatMemoryName = "chat-ai-service-memory")
public interface ChatAiService {

@SystemMessage("""
@@ -11,7 +11,7 @@
import dev.langchain4j.service.V;
import io.smallrye.llm.spi.RegisterAIService;

@RegisterAIService(chatMemoryMaxMessages = 5, chatLanguageModelName = "chat-model")
@RegisterAIService(chatMemoryName = "fraud-ai-service-memory", chatLanguageModelName = "chat-model")
public interface FraudAiService {

@SystemMessage("""
@@ -20,6 +20,15 @@ smallrye.llm.plugin.docRagRetriever.config.embeddingModel=lookup:default
smallrye.llm.plugin.docRagRetriever.config.maxResults=3
smallrye.llm.plugin.docRagRetriever.config.minScore=0.6

# Chat Memory used by ChatAiService class
smallrye.llm.plugin.chat-ai-service-memory.class=dev.langchain4j.memory.chat.MessageWindowChatMemory
smallrye.llm.plugin.chat-ai-service-memory.scope=jakarta.enterprise.context.ApplicationScoped
smallrye.llm.plugin.chat-ai-service-memory.config.maxMessages=10

# Chat Memory used by FraudAiService class
smallrye.llm.plugin.fraud-ai-service-memory.class=dev.langchain4j.memory.chat.MessageWindowChatMemory
smallrye.llm.plugin.fraud-ai-service-memory.scope=jakarta.enterprise.context.ApplicationScoped
smallrye.llm.plugin.fraud-ai-service-memory.config.maxMessages=5


smallrye.llm.embedding.store.in-memory.file=embedding.json
@@ -10,7 +10,7 @@
import io.smallrye.llm.spi.RegisterAIService;

@SuppressWarnings("CdiManagedBeanInconsistencyInspection")
@RegisterAIService(tools = BookingService.class, chatMemoryMaxMessages = 10)
@RegisterAIService(tools = BookingService.class, chatMemoryName = "chat-ai-service-memory")
public interface ChatAiService {

@SystemMessage("""
@@ -12,7 +12,7 @@
import io.smallrye.llm.spi.RegisterAIService;

@SuppressWarnings("CdiManagedBeanInconsistencyInspection")
@RegisterAIService(chatMemoryMaxMessages = 5)
@RegisterAIService(chatMemoryName = "fraud-ai-service-memory")
public interface FraudAiService {

@SystemMessage("""
@@ -1,7 +1,6 @@
# Microprofile server properties
quarkus.http.port=8090


smallrye.llm.plugin.chat-model.class=dev.langchain4j.model.azure.AzureOpenAiChatModel
smallrye.llm.plugin.chat-model.config.api-key=${azure.openai.api.key}
smallrye.llm.plugin.chat-model.config.endpoint=${azure.openai.endpoint}
@@ -13,14 +12,21 @@ smallrye.llm.plugin.chat-model.config.timeout=120s
smallrye.llm.plugin.chat-model.config.max-retries=2
#smallrye.llm.plugin.chat-model.config.logRequestsAndResponses=false


smallrye.llm.plugin.docRagRetriever.class=dev.langchain4j.rag.content.retriever.EmbeddingStoreContentRetriever
smallrye.llm.plugin.docRagRetriever.config.embeddingStore=lookup:default
smallrye.llm.plugin.docRagRetriever.config.embeddingModel=lookup:default
smallrye.llm.plugin.docRagRetriever.config.maxResults=3
smallrye.llm.plugin.docRagRetriever.config.minScore=0.6

# Chat Memory used by ChatAiService class
smallrye.llm.plugin.chat-ai-service-memory.class=dev.langchain4j.memory.chat.MessageWindowChatMemory
smallrye.llm.plugin.chat-ai-service-memory.scope=jakarta.enterprise.context.ApplicationScoped
smallrye.llm.plugin.chat-ai-service-memory.config.maxMessages=10

# Chat Memory used by FraudAiService class
smallrye.llm.plugin.fraud-ai-service-memory.class=dev.langchain4j.memory.chat.MessageWindowChatMemory
smallrye.llm.plugin.fraud-ai-service-memory.scope=jakarta.enterprise.context.ApplicationScoped
smallrye.llm.plugin.fraud-ai-service-memory.config.maxMessages=5

smallrye.llm.embedding.store.in-memory.file=embedding.json

@@ -32,6 +38,3 @@ fraud.memory.max.messages=20

# Location of documents to RAG
app.docs-for-rag.dir=docs-for-rag
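
For reference, a hedged sketch of the chat-model entries above expressed as direct LangChain4j builder calls. The ${...} placeholders are passed in as parameters, only the options visible in this hunk are mapped, and the deployment name is assumed to come from a key elsewhere in the file; treat the exact set of supported keys as an assumption.

import java.time.Duration;

import dev.langchain4j.model.azure.AzureOpenAiChatModel;
import dev.langchain4j.model.chat.ChatLanguageModel;

public class AzureChatModelSketch {

    static ChatLanguageModel chatModel(String apiKey, String endpoint, String deploymentName) {
        return AzureOpenAiChatModel.builder()
                .apiKey(apiKey)                   // config.api-key=${azure.openai.api.key}
                .endpoint(endpoint)               // config.endpoint=${azure.openai.endpoint}
                .deploymentName(deploymentName)   // assumed to be configured outside this hunk
                .timeout(Duration.ofSeconds(120)) // config.timeout=120s
                .maxRetries(2)                    // config.max-retries=2
                .build();
    }
}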



@@ -1,8 +1,6 @@
package io.smallrye.llm.aiservice;

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.util.ArrayList;
import java.util.List;

@@ -11,115 +9,91 @@

import org.jboss.logging.Logger;

import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.memory.chat.ChatMemoryProvider;
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.moderation.ModerationModel;
import dev.langchain4j.rag.RetrievalAugmentor;
import dev.langchain4j.rag.content.retriever.ContentRetriever;
import dev.langchain4j.service.AiServices;
import dev.langchain4j.service.MemoryId;
import dev.langchain4j.service.Moderate;
import dev.langchain4j.store.memory.chat.ChatMemoryStore;
import io.smallrye.llm.core.langchain4j.core.config.spi.ChatMemoryFactoryProvider;
import io.smallrye.llm.spi.RegisterAIService;

public class CommonAIServiceCreator {

private static final Logger LOGGER = Logger.getLogger(CommonAIServiceCreator.class);

@SuppressWarnings("unchecked")
public static <X> X create(Instance<Object> lookup, Class<X> interfaceClass) {
RegisterAIService annotation = interfaceClass.getAnnotation(RegisterAIService.class);
Instance<ChatLanguageModel> chatLanguageModel = getInstance(lookup, ChatLanguageModel.class,
annotation.chatLanguageModelName());
Instance<StreamingChatLanguageModel> streamingChatLanguageModel = getInstance(lookup, StreamingChatLanguageModel.class,
annotation.streamingChatLanguageModelName());
Instance<ContentRetriever> contentRetriever = getInstance(lookup, ContentRetriever.class,
annotation.contentRetrieverName());
try {
AiServices<?> aiServices = AiServices.builder(interfaceClass);
if (chatLanguageModel.isResolvable()) {
LOGGER.info("ChatLanguageModel " + chatLanguageModel.get());
aiServices.chatLanguageModel(chatLanguageModel.get());
}
if (contentRetriever.isResolvable()) {
LOGGER.info("ContentRetriever " + contentRetriever.get());
aiServices.contentRetriever(contentRetriever.get());
}
if (annotation.tools() != null && annotation.tools().length > 0) {
List<Object> tools = new ArrayList<>(annotation.tools().length);
for (Class<?> toolClass : annotation.tools()) {
try {
tools.add(toolClass.getConstructor(null).newInstance(null));
} catch (NoSuchMethodException | SecurityException | InstantiationException | IllegalAccessException
| IllegalArgumentException | InvocationTargetException ex) {
}
}
aiServices.tools(tools);
}
Instance<RetrievalAugmentor> retrievalAugmentor = getInstance(lookup, RetrievalAugmentor.class,
annotation.retrievalAugmentorName());

ChatMemoryProvider chatMemoryProvider = createChatMemoryProvider(lookup, interfaceClass, annotation);
if (chatMemoryProvider != null) {
aiServices.chatMemoryProvider(chatMemoryProvider);
} else {
aiServices.chatMemory(
ChatMemoryFactoryProvider.getChatMemoryFactory().getChatMemory(lookup,
annotation.chatMemoryMaxMessages()));
AiServices<X> aiServices = AiServices.builder(interfaceClass);
if (chatLanguageModel != null && chatLanguageModel.isResolvable()) {
LOGGER.info("ChatLanguageModel " + chatLanguageModel.get());
aiServices.chatLanguageModel(chatLanguageModel.get());
}
if (streamingChatLanguageModel != null && streamingChatLanguageModel.isResolvable()) {
LOGGER.info("StreamingChatLanguageModel " + streamingChatLanguageModel.get());
aiServices.streamingChatLanguageModel(streamingChatLanguageModel.get());
}
if (contentRetriever != null && contentRetriever.isResolvable()) {
LOGGER.info("ContentRetriever " + contentRetriever.get());
aiServices.contentRetriever(contentRetriever.get());
}
if (retrievalAugmentor != null && retrievalAugmentor.isResolvable()) {
LOGGER.info("RetrievalAugmentor " + retrievalAugmentor.get());
aiServices.retrievalAugmentor(retrievalAugmentor.get());
}
if (annotation.tools() != null && annotation.tools().length > 0) {
List<Object> tools = new ArrayList<>(annotation.tools().length);
for (Class<?> toolClass : annotation.tools()) {
try {
tools.add(toolClass.getConstructor((Class<?>[]) null).newInstance((Object[]) null));
} catch (NoSuchMethodException | SecurityException | InstantiationException | IllegalAccessException
| IllegalArgumentException | InvocationTargetException ex) {
}
}
aiServices.tools(tools);
}

ModerationModel moderationModel = findModerationModel(lookup, interfaceClass, annotation);
if (moderationModel != null) {
aiServices.moderationModel(moderationModel);
}
return (X) aiServices.build();
} catch (Exception e) {
throw new RuntimeException(e);
Instance<ChatMemory> chatMemory = getInstance(lookup, ChatMemory.class,
annotation.chatMemoryName());
if (chatMemory != null && chatMemory.isResolvable()) {
LOGGER.info("ChatMemory " + chatMemory.get());
aiServices.chatMemory(chatMemory.get());
}
}

private static <X> Instance<X> getInstance(Instance<Object> lookup, Class<X> type, String name) {
LOGGER.info("Getinstance of '" + type + "' with name '" + name + "'");
if (name == null || name.isBlank()) {
return lookup.select(type);
Instance<ChatMemoryProvider> chatMemoryProvider = getInstance(lookup, ChatMemoryProvider.class,
annotation.chatMemoryProviderName());
if (chatMemoryProvider != null && chatMemoryProvider.isResolvable()) {
LOGGER.info("ChatMemoryProvider " + chatMemoryProvider.get());
aiServices.chatMemoryProvider(chatMemoryProvider.get());
}
return lookup.select(type, NamedLiteral.of(name));
}

private static ModerationModel findModerationModel(Instance<Object> lookup, Class<?> interfaceClass,
RegisterAIService registerAIService) {
//Get all methods.
for (Method method : interfaceClass.getMethods()) {
Moderate moderate = method.getAnnotation(Moderate.class);
if (moderate != null) {
Instance<ModerationModel> moderationModelInstance = getInstance(lookup, ModerationModel.class,
registerAIService.moderationModelName());
if (moderationModelInstance != null && moderationModelInstance.isResolvable())
return moderationModelInstance.get();
}
Instance<ModerationModel> moderationModelInstance = getInstance(lookup, ModerationModel.class,
annotation.moderationModelName());
if (moderationModelInstance != null && moderationModelInstance.isResolvable()) {
LOGGER.info("ModerationModel " + moderationModelInstance.get());
aiServices.moderationModel(moderationModelInstance.get());
}

return null;
return aiServices.build();
}

private static ChatMemoryProvider createChatMemoryProvider(Instance<Object> lookup, Class<?> interfaceClass,
RegisterAIService registerAIService) {
//Get all methods.
for (Method method : interfaceClass.getMethods()) {
for (Parameter parameter : method.getParameters()) {
MemoryId memoryIdAnnotation = parameter.getAnnotation(MemoryId.class);
if (memoryIdAnnotation != null) {
Instance<ChatMemoryStore> chatMemoryStore = getInstance(lookup, ChatMemoryStore.class,
registerAIService.chatMemoryStoreName());
if (chatMemoryStore == null || !chatMemoryStore.isResolvable()) {
throw new IllegalStateException("Unable to resolve a ChatMemoryStore for your ChatMemoryProvider.");
}
private static <X> Instance<X> getInstance(Instance<Object> lookup, Class<X> type, String name) {
LOGGER.info("CDI get instance of type '" + type + "' with name '" + name + "'");
if (name != null && !name.isBlank()) {
if ("#default".equals(name))
return lookup.select(type);

ChatMemoryProvider chatMemoryProvider = memoryId -> MessageWindowChatMemory.builder()
.id(memoryId)
.maxMessages(registerAIService.chatMemoryMaxMessages())
.chatMemoryStore(chatMemoryStore.get())
.build();
return chatMemoryProvider;
}
}
return lookup.select(type, NamedLiteral.of(name));
}

return null;
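A condensed, hedged restatement of the new lookup flow shown in the diff: named attributes on @RegisterAIService (for example chatMemoryName = "chat-ai-service-memory") are resolved as @Named CDI beans, with "#default" falling back to the unqualified bean. The helper below is illustrative, not the committed code.

import dev.langchain4j.memory.ChatMemory;
import jakarta.enterprise.inject.Instance;
import jakarta.enterprise.inject.literal.NamedLiteral;

public class NamedBeanLookupSketch {

    // Mirrors the diff's getInstance(...): a blank name or "#default" selects the
    // default bean of the type, anything else selects the @Named bean.
    static <T> Instance<T> resolve(Instance<Object> lookup, Class<T> type, String name) {
        if (name == null || name.isBlank() || "#default".equals(name)) {
            return lookup.select(type);
        }
        return lookup.select(type, NamedLiteral.of(name));
    }

    // Example: fetch the chat memory configured under chat-ai-service-memory, if present.
    static ChatMemory chatMemoryOrNull(Instance<Object> lookup) {
        Instance<ChatMemory> memory = resolve(lookup, ChatMemory.class, "chat-ai-service-memory");
        return memory.isResolvable() ? memory.get() : null;
    }
}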