Enhancing ChatGPT with Prompt Engineering and Token Counting in Java

By Marcus Hellberg · On Jul 28, 2023 5:14:26 PM

This article is part three of the Building an AI chatbot in Java series, where we're building a custom AI chatbot application using Hilla, Spring Boot, React, OpenAI (ChatGPT), and Pinecone. The chatbot is designed to answer Vaadin Flow and Hilla development questions using up-to-date documentation as reference material.

In this part, we'll create a service that provides context-aware chat completions. It calls the Pinecone service to fetch relevant documentation sections, constructs a ChatGPT prompt from them, and passes that prompt to the OpenAI service for chat completion.

Requirements

The tutorial assumes you have a Hilla project based on Spring Boot and React that includes the services built in the previous two parts. You can find the complete source below if you are new to the series.

Add the following dependency to your pom.xml to count tokens.

<dependency>
    <groupId>com.knuddels</groupId>
    <artifactId>jtokkit</artifactId>
    <version>0.6.1</version> <!-- check latest version -->
</dependency>

Source code for the completed application

You can find the completed source code for the application on my GitHub, https://github.com/marcushellberg/docs-assistant.

Application overview

Here's how the application works on a high level:

  1. A user enters a query in the browser.
  2. Moderate the input to ensure it adheres to the content policy.
  3. Find the parts of the documentation most relevant to answering the question by querying the Pinecone vector database.
  4. Construct a ChatGPT completion query with a prompt, the relevant documentation, and the chat history. Count tokens to ensure maximal usage of the context size without exceeding it.
  5. Stream the response back to the user.

We'll create the web UI in the next part of the series. In this part, we'll focus on the service orchestrating all the moving pieces.

Create a service for handling chat completion requests

The core of the application is a Spring service class, DocsAssistantService.java. It orchestrates the completion request and counts tokens, delegating to OpenAIService and PineconeService to access outside APIs.

Start by creating the service, defining token limits, and injecting the needed services. Instantiate the tokenizer library with the correct encoding type.

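import com.knuddels.jtokkit.Encodings;
import com.knuddels.jtokkit.api.Encoding;
import com.knuddels.jtokkit.api.EncodingRegistry;
import com.knuddels.jtokkit.api.EncodingType;
import org.springframework.stereotype.Service;

import java.util.List;

@Service
public class DocsAssistantService {

    // Total context window of the model, in tokens
    private static final int MAX_TOKENS = 4096;
    // Tokens reserved for the model's response
    private static final int MAX_RESPONSE_TOKENS = 1024;
    // Tokens allotted to the documentation snippets in the prompt
    private static final int MAX_CONTEXT_TOKENS = 1536;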

    private final OpenAIService openAIService;
    private final PineconeService pineconeService;
    private final Encoding tokenizer;

    private final List<Framework> supportedFrameworks = List.of(
            new Framework("Flow", "flow"),
            new Framework("Hilla with React", "hilla-react"),
            new Framework("Hilla with Lit", "hilla-lit")
    );

    public DocsAssistantService(OpenAIService openAIService, PineconeService pineconeService) {
        this.openAIService = openAIService;
        this.pineconeService = pineconeService;
        EncodingRegistry registry = Encodings.newDefaultEncodingRegistry();
        tokenizer = registry.getEncoding(EncodingType.CL100K_BASE);
    }

    public List<Framework> getSupportedFrameworks() {
        return supportedFrameworks;
    }

}

The code above defines the three frameworks and their corresponding namespaces in the Pinecone vector database.
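The Framework type itself is not shown in this article. Here is a minimal sketch of what it might look like, assuming it is a simple value class that pairs a display label with a Pinecone namespace. The getLabel and getValue accessors match the calls used later in the service; the FrameworkLookup helper is hypothetical and exists only to illustrate the lookup.

```java
import java.util.List;
import java.util.Optional;

// Hypothetical value class: a display label plus the Pinecone namespace it maps to.
final class Framework {
    private final String label;
    private final String value;

    Framework(String label, String value) {
        this.label = label;
        this.value = value;
    }

    String getLabel() { return label; }

    String getValue() { return value; }
}

// Hypothetical helper, for illustration only.
final class FrameworkLookup {
    static final List<Framework> SUPPORTED = List.of(
            new Framework("Flow", "flow"),
            new Framework("Hilla with React", "hilla-react"),
            new Framework("Hilla with Lit", "hilla-lit"));

    // Resolves a namespace value such as "hilla-react" to its label, if supported.
    static Optional<String> labelFor(String value) {
        return SUPPORTED.stream()
                .filter(f -> f.getValue().equals(value))
                .findFirst()
                .map(Framework::getLabel);
    }
}
```

Returning an Optional leaves it to the caller to decide how to handle an unsupported framework.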

Call ChatGPT with custom documents as context

Next, create a method that orchestrates the entire process. It returns a Flux<String> so we can stream the response to the user's browser.

public Flux<String> getCompletionStream(List<ChatCompletionMessage> history, String framework) {
    if (history.isEmpty()) {
        return Flux.error(new RuntimeException("History is empty"));
    }

    // The last message in the history is the user's current question
    var question = history.get(history.size() - 1).getContent();

    return openAIService
            .moderate(history)
            .flatMap(isContentSafe -> isContentSafe ?
                    openAIService.createEmbedding(question) :
                    Mono.error(new RuntimeException("Message failed moderation")))
            .flatMap(embedding -> pineconeService.findSimilarDocuments(embedding, 10, framework))
            .map(similarDocuments -> getPromptWithContext(history, similarDocuments, framework))
            .flatMapMany(openAIService::generateCompletionStream);
}

Here's what the code does:

  1. It extracts the question from the message history.
  2. It moderates the entire message history (the user could have changed prior messages in their browser).
  3. If the messages pass moderation, it creates a vector embedding from the question.
  4. It then fetches up to the 10 most closely related documentation snippets from Pinecone, using the selected framework as the namespace.
  5. It then creates the prompt, including as many documentation snippets as fit within the allotted tokens.
  6. Finally, it calls ChatGPT through the OpenAI service and streams the response.
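To make the ordering of these steps easier to follow without the reactive operators, here is a synchronous sketch of the same flow in plain Java. It is not the service's actual implementation: PipelineSketch and its functional parameters are hypothetical stand-ins for the moderation, embedding, Pinecone, and completion calls, and messages are reduced to plain strings.

```java
import java.util.List;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.function.Predicate;

// Hypothetical, blocking illustration of the step order; the real service is reactive.
final class PipelineSketch {
    static String answer(List<String> history,
                         Predicate<List<String>> isSafe,
                         Function<String, float[]> embed,
                         Function<float[], List<String>> findSimilar,
                         BiFunction<List<String>, List<String>, String> complete) {
        if (history.isEmpty()) {
            throw new IllegalArgumentException("History is empty");
        }
        // 1. The question is the last message in the history.
        String question = history.get(history.size() - 1);
        // 2. Moderate the whole history, not just the question.
        if (!isSafe.test(history)) {
            throw new IllegalStateException("Message failed moderation");
        }
        // 3. Embed the question, 4. look up related docs, 5-6. build the prompt and complete.
        float[] embedding = embed.apply(question);
        List<String> docs = findSimilar.apply(embedding);
        return complete.apply(history, docs);
    }
}
```

Each later step only runs if the earlier one succeeded, which is the same short-circuiting the flatMap chain provides.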

Prompt engineering

The getPromptWithContext method is responsible for constructing the list of messages to send to ChatGPT, including the system prompt, the documentation we want to pass in as context, and history.

private List<ChatCompletionMessage> getPromptWithContext(List<ChatCompletionMessage> history, List<String> contextDocs, String framework) {
    var contextString = getContextString(contextDocs);
    var systemMessages = new ArrayList<ChatCompletionMessage>();
    // Resolve the namespace value (for example "hilla-react") to its display label
    var fullFramework = supportedFrameworks.stream()
            .filter(f -> f.getValue().equals(framework))
            .findFirst()
            .orElseThrow()
            .getLabel();

    // System prompt: sets the persona and the framework to answer about
    systemMessages.add(new ChatCompletionMessage(
            ChatCompletionMessage.Role.SYSTEM,
            String.format("You are a senior Vaadin expert. You love to help developers! Answer the user's question about %s development with the help of the information in the provided documentation.", fullFramework)
    ));
    // Relevant documentation, delimited by ===
    systemMessages.add(new ChatCompletionMessage(
            ChatCompletionMessage.Role.USER,
            String.format(
                    """
                    Here is the documentation:
                    ===
                    %s
                    ===
                    """, contextString)
    ));
    // Formatting rules for the answer
    systemMessages.add(new ChatCompletionMessage(
            ChatCompletionMessage.Role.USER,
            """
            You must also follow the below rules when answering:
            - Prefer splitting your response into multiple paragraphs
            - Output as markdown
            - Always include code snippets if available
            """
    ));

    return capMessages(systemMessages, history);
}

The three most important things to pay attention to in the prompt:

  1. We instruct the LLM to act as an expert and answer questions about our selected framework.
  2. We pass in the relevant documentation in a block delimited by ===.
  3. We instruct the LLM to answer in markdown and include code snippets.
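To see what the delimited documentation message looks like as text, here is a small standalone sketch. PromptSketch is a hypothetical helper, not part of the service; it only reproduces the string formatting with a plain format string instead of a text block.

```java
// Hypothetical helper that mirrors the documentation message format.
final class PromptSketch {
    // Wraps the context in === delimiters so the model can tell
    // the reference material apart from the instructions.
    static String documentationMessage(String context) {
        return String.format("Here is the documentation:\n===\n%s\n===\n", context);
    }
}
```

For a context string of "Doc A\n---\n", the resulting message is the header line, an === line, the snippet followed by its --- separator and an empty line, and a closing === line.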

Counting tokens

Counting tokens is critical to getting the most relevant results: it lets us pack as much documentation and chat history into the prompt as the model allows without overflowing the context window. In this case, we're working with a 4096-token context, of which we have allotted up to 1536 tokens to the documentation snippets and reserved 1024 tokens for the response.
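The budget arithmetic can be spelled out as a short sketch. The constants are the ones from the service; the TokenBudget class and its method names are hypothetical and only there to show the calculation.

```java
// Hypothetical helper illustrating how the token budget divides up.
final class TokenBudget {
    static final int MAX_TOKENS = 4096;
    static final int MAX_RESPONSE_TOKENS = 1024;
    static final int MAX_CONTEXT_TOKENS = 1536;

    // Tokens available for everything we send: instructions, docs, and history.
    static int promptBudget() {
        return MAX_TOKENS - MAX_RESPONSE_TOKENS;
    }

    // Tokens left for instructions and history when the docs use their full allotment.
    static int minimumForInstructionsAndHistory() {
        return promptBudget() - MAX_CONTEXT_TOKENS;
    }
}
```

With these values, 3072 tokens are available for the prompt, and at least 1536 of them remain for the instructions and chat history.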

When converting the documentation snippets into a string, we only include up to MAX_CONTEXT_TOKENS worth of the most relevant docs:

private String getContextString(List<String> contextDocs) {
    var tokenCount = 0;
    var stringBuilder = new StringBuilder();
    for (var doc : contextDocs) {
        tokenCount += tokenizer.encode(doc + "\n---\n").size();
        if (tokenCount > MAX_CONTEXT_TOKENS) {
            break;
        }
        stringBuilder.append(doc);
        stringBuilder.append("\n---\n");
    }

    return stringBuilder.toString();
}
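The cutoff behavior is easier to see with small numbers. The sketch below repeats the same loop, but takes the limit and the token counter as parameters. WORD_COUNT is a stand-in that counts whitespace-separated words, which is an assumption for illustration only and does not match the real CL100K_BASE tokenizer; ContextTrimSketch is likewise hypothetical.

```java
import java.util.List;
import java.util.function.ToIntFunction;

// Hypothetical, self-contained version of the context-trimming loop.
final class ContextTrimSketch {
    static final String SEPARATOR = "\n---\n";

    // Stand-in tokenizer (assumption): one token per whitespace-separated word.
    static final ToIntFunction<String> WORD_COUNT =
            s -> s.trim().isEmpty() ? 0 : s.trim().split("\\s+").length;

    static String contextString(List<String> docs, int maxTokens, ToIntFunction<String> counter) {
        int tokenCount = 0;
        StringBuilder sb = new StringBuilder();
        for (String doc : docs) {
            // Count the separator too, since it is sent along with the doc.
            tokenCount += counter.applyAsInt(doc + SEPARATOR);
            if (tokenCount > maxTokens) {
                break; // This doc would exceed the budget; drop it and all later ones.
            }
            sb.append(doc).append(SEPARATOR);
        }
        return sb.toString();
    }
}
```

Because the documents arrive ordered by relevance, stopping at the first one that does not fit keeps the most relevant snippets and drops the rest.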

Finally, we remove old message history as needed to ensure we have 1024 tokens available for the response.

/**
 * Removes old messages from the history until the total number of tokens + MAX_RESPONSE_TOKENS fits within MAX_TOKENS
 *
 * @param systemMessages The system messages including context and prompt
 * @param history The history of messages. The last message is the user question, do not remove it.
 * @return The capped messages that can be sent to the OpenAI API.
 */
private List<ChatCompletionMessage> capMessages(List<ChatCompletionMessage> systemMessages,
                                                List<ChatCompletionMessage> history) {
    var availableTokens = MAX_TOKENS - MAX_RESPONSE_TOKENS;
    var cappedHistory = new ArrayList<>(history);

    // Note: getTokenCount adds the 3-token reply priming on each call, so it is
    // counted twice here. This slightly overestimates, which errs on the safe side.
    var tokens = getTokenCount(systemMessages) + getTokenCount(cappedHistory);

    while (tokens > availableTokens) {
        if (cappedHistory.size() == 1) {
            throw new RuntimeException("Cannot cap messages further, only user question left");
        }

        // Drop the oldest message first
        cappedHistory.remove(0);
        tokens = getTokenCount(systemMessages) + getTokenCount(cappedHistory);
    }

    var cappedMessages = new ArrayList<>(systemMessages);
    cappedMessages.addAll(cappedHistory);

    return cappedMessages;
}
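Here is a simplified sketch of the same trimming strategy that can be run on its own. It treats messages as plain strings, ignores the per-message overhead, and uses a word count as a stand-in tokenizer; all of these are assumptions for illustration, and HistoryCapSketch is a hypothetical name.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.ToIntFunction;

// Hypothetical, simplified version of the history-capping loop.
final class HistoryCapSketch {
    // Stand-in tokenizer (assumption): one token per whitespace-separated word.
    static final ToIntFunction<String> WORD_COUNT =
            s -> s.trim().isEmpty() ? 0 : s.trim().split("\\s+").length;

    static int total(List<String> messages, ToIntFunction<String> counter) {
        int sum = 0;
        for (String message : messages) {
            sum += counter.applyAsInt(message);
        }
        return sum;
    }

    static List<String> cap(List<String> system, List<String> history,
                            int budget, ToIntFunction<String> counter) {
        List<String> capped = new ArrayList<>(history);
        int fixed = total(system, counter); // System messages are never dropped.
        while (fixed + total(capped, counter) > budget) {
            if (capped.size() <= 1) {
                throw new IllegalStateException("Cannot cap messages further, only user question left");
            }
            capped.remove(0); // Drop the oldest message first.
        }
        List<String> result = new ArrayList<>(system);
        result.addAll(capped);
        return result;
    }
}
```

With a system message of 2 tokens, a history of 3, 2, and 1 tokens, and a budget of 5, only the oldest history message is dropped.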


/**
 * Returns the number of tokens in the messages.
 * See https://github.com/openai/openai-cookbook/blob/834181d5739740eb8380096dac7056c925578d9a/examples/How_to_count_tokens_with_tiktoken.ipynb
 *
 * @param messages The messages to count the tokens of
 * @return The number of tokens in the messages
 */
private int getTokenCount(List<ChatCompletionMessage> messages) {
    var tokenCount = 3; // every reply is primed with <|start|>assistant<|message|>
    for (var message : messages) {
        tokenCount += getMessageTokenCount(message);
    }
    return tokenCount;
}

/**
 * Returns the number of tokens in the message.
 *
 * @param message The message to count the tokens of
 * @return The number of tokens in the message
 */
private int getMessageTokenCount(ChatCompletionMessage message) {
    var tokens = 4; // every message follows <|start|>{role/name}\n{content}<|end|>\n

    tokens += tokenizer.encode(message.getRole().toString()).size();
    tokens += tokenizer.encode(message.getContent()).size();

    return tokens;
}
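Put together, the two methods compute 3 + the sum over all messages of (4 + role tokens + content tokens). The sketch below works through that formula on a tiny example. The word-count tokenizer and the MessageTokenSketch and Message names are hypothetical stand-ins, so the absolute numbers will differ from what the real tokenizer reports.

```java
import java.util.List;
import java.util.function.ToIntFunction;

// Hypothetical worked example of the message token-counting formula.
final class MessageTokenSketch {
    static final int REPLY_PRIMING_TOKENS = 3;
    static final int PER_MESSAGE_OVERHEAD = 4;

    // Stand-in tokenizer (assumption): one token per whitespace-separated word.
    static final ToIntFunction<String> WORD_COUNT =
            s -> s.trim().isEmpty() ? 0 : s.trim().split("\\s+").length;

    // Minimal stand-in for a chat message.
    static final class Message {
        final String role;
        final String content;

        Message(String role, String content) {
            this.role = role;
            this.content = content;
        }
    }

    static int count(List<Message> messages, ToIntFunction<String> counter) {
        int tokens = REPLY_PRIMING_TOKENS;
        for (Message message : messages) {
            tokens += PER_MESSAGE_OVERHEAD
                    + counter.applyAsInt(message.role)
                    + counter.applyAsInt(message.content);
        }
        return tokens;
    }
}
```

For a system message "You are helpful" and a user message "Hi", the word-count stand-in gives 3 + (4 + 1 + 3) + (4 + 1 + 1) = 17.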

Next steps

In the following article in the Building an AI chatbot in Java series, we will use the service we just built to create a React UI for the AI chatbot, so users have a convenient way of interacting with the application.

Marcus Hellberg
Marcus is the VP of Developer Relations at Vaadin. His daily work includes everything from writing blogs and tech demos to attending events and giving presentations on all things Vaadin and web-related. You can reach out to him on Twitter @marcushellberg.