How would you implement Tool Calling / Function Calling?
I managed to get it working, and I'd like to share my approach and start a discussion about it. While building function calling, I had several things in mind:
- Using just the Flux from Spring’s WebClient wasn’t enough. Besides sending the completion chunks to the client, I now also have to tell the user when a function is being called. => The DTO for endpoint communication had to be extended.
- I also need to react to function calls on the server side. By design, the original stream from the LLM stops and contains the information about which function to call and, depending on the moon cycle, some text for the user like “I need to call this tool for that”. Then the tool has to be executed, and its result must be added to a new call to the chat completion endpoint. => The stream needs to be decoupled. I created a reactor.core.publisher.Sinks.Many for the EndpointSubscription, so I can now send chunks from multiple completion responses and also update the UI while the function is executing (progress and info).
- Support both non-streaming and streaming completions.
- Support zero to multiple functions called in the same completion.
- New functions are easy to add by implementing an interface: ToolImpementation.
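To make the control flow behind these points concrete, here is a JDK-only sketch of the round-trip loop, independent of Reactor and WebClient. It assumes the streamed deltas have already been assembled into one reply per round; `Message`, `ToolRequest`, `ModelReply` and `ToolLoop` are hypothetical names, and the `Consumer<String>` plays the role of the `Sinks.Many`: a single outbound channel that survives several completion calls and also carries tool progress events.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
import java.util.function.Function;

// Hypothetical, simplified types (not the DTOs from the post).
// The real assistant message would also carry the tool_calls array.
record Message(String role, String content, String toolCallId) {}

record ToolRequest(String id, String name, String arguments) {}

record ModelReply(String content, List<ToolRequest> toolCalls) {
    boolean wantsTools() {
        return toolCalls != null && !toolCalls.isEmpty();
    }
}

class ToolLoop {
    /**
     * Calls the model, executes the requested tools, feeds their results back and
     * repeats until the model answers without tool calls or maxRounds is reached.
     *
     * @param model  sends the history and returns one fully assembled reply
     * @param tools  tool name -> implementation receiving the raw JSON arguments
     * @param events single outbound channel to the client, a stand-in for Sinks.Many
     */
    static String run(List<Message> history,
                      Function<List<Message>, ModelReply> model,
                      Map<String, Function<String, String>> tools,
                      Consumer<String> events,
                      int maxRounds) {
        for (int round = 0; round < maxRounds; round++) {
            ModelReply reply = model.apply(history);
            if (!reply.wantsTools()) {
                events.accept("answer: " + reply.content());
                return reply.content();
            }
            // The assistant turn that requested the tools must precede the tool results.
            history.add(new Message("assistant", null, null));
            for (ToolRequest call : reply.toolCalls()) {
                events.accept("tool: " + call.name()); // lets the UI show progress
                Function<String, String> tool = tools.get(call.name());
                String result = tool != null
                        ? tool.apply(call.arguments())
                        : "error: unknown tool " + call.name();
                history.add(new Message("tool", result, call.id()));
            }
        }
        throw new IllegalStateException("no final answer after " + maxRounds + " rounds");
    }
}
```

The `maxRounds` cap is not part of the original code (which recurses via `order()` without a limit); it is one way to stop a model that keeps requesting tools from looping forever.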
We are using Qwen2.5-14B with vLLM running on an A6000.
Here is my CompletionService:
@Service
public class CompletionService {
/**
* sent to the Hilla client
*/
public static class ClientCompletionResponse {
private @Nonnull Boolean isError = false;
private String errorMessage;
private @Nonnull Boolean isTool = false;
private String toolName;
private Object toolData;
private String toolError;
private String chunk;
private @Nonnull Boolean isResoning = true;
private @Nonnull Boolean isStream = true;
private @Nonnull Boolean botHasSource = false;
private ChatCompletionChunkResponse chunkObject;
// getter, setter
}
/**
* a tool that can be offered to the LLM
*/
public static class ToolEntry {
private String type = "function";
private CompletionTool function;
public ToolEntry(CompletionTool function) {
this.function = function;
}
// getter, setter
}
public static class CompletionTool {
private String name;
private String description;
private ToolParameter parameters;
// getter, setter
public static class ToolParameter {
private String type;
private List<String> required;
private Map<String, ToolProperty> properties = new LinkedHashMap<>();
public ToolParameter() {
}
public ToolParameter(String type, List<String> required, Map<String, ToolProperty> properties) {
this.type = type;
this.required = required;
this.properties = properties;
}
// getter, setter
}
public static interface ToolProperty {
public String getType();
public void setType(String type);
public String getDescription();
public void setDescription(String description);
}
public static class ToolStringProperty implements ToolProperty {
private String type;
private String description;
// getter, setter
}
public static class ToolEnumProperty extends ToolStringProperty {
@JsonProperty("enum")
private List<String> enumValues = new ArrayList<>();
// getter, setter
}
}
private final Logger logger = LoggerFactory.getLogger(CompletionService.class);
// WebClient
private WebClient completionWebClient;
@Value("${company.tool.enableCompletionTools}")
private Boolean enableCompletionTools;
@Value("${company.tool.openweather.appid}")
private String toolOpenWeatherAppId;
private final Map<String, ToolImpementation> tools = new HashMap<>();
private FileGenerationService fileGenerationService;
public CompletionService(FileGenerationService fileGenerationService) {
this.fileGenerationService = fileGenerationService;
}
@PostConstruct
void init() {
this.completionWebClient = createWebClient();
registerTools();
}
public WebClient createWebClient() {
var client = HttpClient.create().responseTimeout(Duration.ofSeconds(60));
return WebClient.builder().clientConnector(new ReactorClientHttpConnector(client))
.defaultHeader("Content-Type", MediaType.APPLICATION_JSON_VALUE)
.build();
}
private void registerTools() {
ToolImpementation[] toolToOffer = new ToolImpementation[] { new WeatherTool(toolOpenWeatherAppId),
new MarkdownToPowerPointTool(fileGenerationService), new KermitApiTool() };
for (ToolImpementation t : toolToOffer) {
List<String> toolNames = t.getToolNames();
for (String functionName : toolNames) {
tools.put(functionName, t);
}
}
}
public void order(CompletionOrder completionOrder, Many<ClientCompletionResponse> clientSink, UserInfo userInfo) {
try {
// set up tools
LinkedHashMap<String, Object> payload = (LinkedHashMap<String, Object>) completionOrder.getPayload();
if (Boolean.TRUE.equals(this.enableCompletionTools)) {
ToolEntry[] toolEntries = this.tools.entrySet().stream()
.map(entry -> entry.getValue().createToolEntry(entry.getKey())).toArray(ToolEntry[]::new);
payload.put("tools", toolEntries);
}
Flux<ChatCompletionChunkResponse> completionFlux = completionWebClient.post()
.uri(completionOrder.getCompletionUrl())
// payload from backend
.bodyValue(payload)
.retrieve()
.bodyToFlux(ChatCompletionChunkResponse.class);
// Reactor operators return a new Flux: reassign, otherwise timeout and onErrorComplete are silently dropped
completionFlux = completionFlux
.timeout(Duration.ofSeconds(60L), Mono.error(new ReadTimeoutException("Timeout")))
.onErrorComplete(error -> {
// The stream terminates with a `[DONE]` message, which causes a serialization error.
// Ignore this error and complete the stream instead.
if (error.getMessage() == null || error.getMessage().contains("JsonToken.START_ARRAY")) {
return true;
}
// Any other error is forwarded to the client
clientSink.tryEmitError(error);
return false;
});
final Disposable disposer = completionFlux
.subscribe(e -> this.onCompletionNext(e, completionOrder, clientSink, userInfo), e -> {
});
clientSink.asFlux().doOnComplete(disposer::dispose).subscribe();
} catch (Exception e) {
logger.error(e.getMessage(), e);
clientSink.tryEmitComplete();
}
}
private void onCompletionNext(ChatCompletionChunkResponse resp, CompletionOrder completionOrder,
Many<ClientCompletionResponse> clientSink, UserInfo userInfo) {
boolean isStream = "chat.completion.chunk".equals(resp.getObject());
final var clientResp = new ClientCompletionResponse();
clientResp.setChunkObject(resp);
clientResp.botHasSource = completionOrder.getHasSources();
for (final var choice : resp.getChoices()) {
String finish = choice.getFinishReason();
Delta delta = isStream ? choice.getDelta() : choice.getMessage();
if (delta != null) {
final String content = delta.getContent();
logger.debug("Received chunk: {}", content);
final String reasoningContent = delta.getReasoningContent();
final List<ToolCall> toolCalls = delta.getTool_calls();
clientResp.setChunk(content != null ? content : reasoningContent);
clientResp.setIsResoning(reasoningContent != null);
if (toolCalls != null && !toolCalls.isEmpty()) {
this.processToolCalls(toolCalls, resp, clientSink);
}
if (isStream && finish == null) {
clientSink.tryEmitNext(clientResp);
}
}
if ("tool_calls".equalsIgnoreCase(finish)) {
this.onToolCallStop(resp, completionOrder, clientResp, clientSink, userInfo);
} else if (finish != null) {
// default is "finish_reason":"stop"
clientSink.tryEmitNext(clientResp);
// TODO when multiple choices are supported, this cannot stop here
clientSink.tryEmitComplete();
}
// break every time after the first choice, ignore the others
break;
}
}
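// shared by all concurrent completions of this singleton service, so it must be thread-safe
private final Map<String, List<ToolCall>> toolsPayloadBuffer = new ConcurrentHashMap<>();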
private void processToolCalls(List<ToolCall> toolCalls, ChatCompletionChunkResponse resp,
Many<ClientCompletionResponse> clientSink) {
if (!toolsPayloadBuffer.containsKey(resp.getId())) {
for (ToolCall toolCall : toolCalls) {
final var clientResp = new ClientCompletionResponse();
clientResp.setIsTool(true);
clientResp.setToolName(toolCall.getFunction().getName());
clientResp.setIsStream(false);
clientSink.tryEmitNext(clientResp);
}
}
toolsPayloadBuffer.compute(resp.getId(), (key, value) -> updateBufferedToolCalls(toolCalls, value));
}
private List<ToolCall> updateBufferedToolCalls(List<ToolCall> newToolCalls, List<ToolCall> currentToolCalls) {
if (currentToolCalls == null) {
currentToolCalls = newToolCalls;
} else {
for (ToolCall toolCall : newToolCalls) {
Optional<ToolCall> currentValue = currentToolCalls.stream()
.filter(t -> Objects.equals(toolCall.getIndex(), t.getIndex())).findFirst();
if (currentValue.isEmpty()) {
// the first fragment of a tool call is not split up: it carries id, type, index and function.name
// (in non-streaming responses the arguments are filled as well)
currentToolCalls.add(toolCall);
} else if (toolCall.getFunction() != null && toolCall.getFunction().getArguments() != null) {
// update function.arguments on streaming
ToolCall target = currentValue.get();
String currentArguments = target.getFunction().getArguments();
if (currentArguments == null) {
currentArguments = "";
}
String newArguments = toolCall.getFunction().getArguments();
target.getFunction().setArguments(currentArguments + newArguments);
}
}
}
return currentToolCalls;
}
/**
* execute tools
*/
private void onToolCallStop(ChatCompletionChunkResponse resp, CompletionOrder completionOrder,
ClientCompletionResponse clientResp, Many<ClientCompletionResponse> clientSink, UserInfo userInfo) {
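List<ToolCall> toolCalls = this.toolsPayloadBuffer.remove(resp.getId());
if (toolCalls == null || toolCalls.isEmpty()) {
// finish_reason was "tool_calls" but nothing was buffered: finish the answer instead of failing with an NPE
clientSink.tryEmitNext(clientResp);
clientSink.tryEmitComplete();
return;
}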
LinkedHashMap<String, Object> payload = (LinkedHashMap<String, Object>) completionOrder.getPayload();
ArrayList<Delta> messages = (ArrayList<Delta>) payload.get("messages");
// required: add the assistant message with the tool calls so the tool results can reference their ids
Delta toolRequestDelta = new Delta();
toolRequestDelta.setRole("assistant");
toolRequestDelta.setTool_calls(toolCalls);
messages.add(toolRequestDelta);
for (ToolCall toolCall : toolCalls) {
ToolImpementation toolImpl = this.tools.get(toolCall.getFunction().getName());
if (toolImpl != null) {
try {
ClientCompletionResponse toolClientResponse = new ClientCompletionResponse();
toolClientResponse.isTool = true;
toolClientResponse.toolName = toolCall.getFunction().getName();
Delta toolResponseDelta = new Delta();
String content = toolImpl.buildToolResponse(toolClientResponse, toolCall, userInfo);
toolResponseDelta.setContent(content);
toolResponseDelta.setRole("tool");
toolResponseDelta.setTool_call_id(toolCall.getId());
messages.add(toolResponseDelta);
clientSink.tryEmitNext(toolClientResponse);
} catch (Exception e) {
logger.error("Error on tool {}", toolCall.getFunction().getName(), e);
clientSink.tryEmitError(e);
// the sink is terminated now, so skip the remaining tools and the follow-up request
return;
}
} else {
var error = new UnsupportedOperationException(
"tool %s cannot be processed".formatted(toolCall.getFunction().getName()));
clientSink.tryEmitError(error);
throw error;
}
}
// send new request with tool response
order(completionOrder, clientSink, userInfo);
}
}
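The buffering in `processToolCalls`/`updateBufferedToolCalls` lives in a singleton `@Service` and is therefore shared by all concurrent requests, which is why a thread-safe map matters there. Below is a JDK-only sketch of the same idea, merging fragments by `index` per completion id. `ToolCallFragment`, `PendingToolCall` and `ToolCallBuffer` are hypothetical names, and the sketch assumes, as the comment in `updateBufferedToolCalls` says, that the first fragment of each call carries `id` and `function.name` while later fragments only append to `arguments`.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.concurrent.ConcurrentHashMap;

// Hypothetical minimal fragment: one entry of a streamed delta.tool_calls array.
record ToolCallFragment(int index, String id, String name, String argumentsPart) {}

// Mutable accumulator for one tool call.
final class PendingToolCall {
    String id;
    String name;
    final StringBuilder arguments = new StringBuilder();
}

class ToolCallBuffer {
    // completion id -> (tool call index -> pending call); concurrent because requests run in parallel
    private final Map<String, Map<Integer, PendingToolCall>> buffer = new ConcurrentHashMap<>();

    void add(String completionId, List<ToolCallFragment> fragments) {
        buffer.compute(completionId, (key, calls) -> {
            // TreeMap keeps the calls ordered by index
            Map<Integer, PendingToolCall> target =
                    calls != null ? calls : new TreeMap<Integer, PendingToolCall>();
            for (ToolCallFragment f : fragments) {
                PendingToolCall call = target.computeIfAbsent(f.index(), i -> new PendingToolCall());
                if (f.id() != null) call.id = f.id();
                if (f.name() != null) call.name = f.name();
                if (f.argumentsPart() != null) call.arguments.append(f.argumentsPart());
            }
            return target;
        });
    }

    // Called on finish_reason "tool_calls": removes and returns the calls ordered by index.
    List<PendingToolCall> drain(String completionId) {
        Map<Integer, PendingToolCall> calls = buffer.remove(completionId);
        return calls == null ? List.of() : new ArrayList<>(calls.values());
    }
}
```

`drain` returns an empty list rather than `null`, so a `tool_calls` finish without buffered calls cannot end in a `NullPointerException`. A non-streaming response simply arrives as a single `add` with complete arguments.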
How the CompletionService is called from the endpoint:
public @Nonnull EndpointSubscription<@Nonnull ClientCompletionResponse> subscribeOnCompletion(
@Nonnull String preCacheKey)
throws EndpointAccessDeniedException {
if (userInfoService.getUserInfo() == null) {
throw new EndpointAccessDeniedException(NOT_AUTHENTICATED);
}
CompletionOrder completionOrder = this.cache.getCompletionOrder(preCacheKey);
Sinks.Many<ClientCompletionResponse> sink = Sinks.many().multicast().onBackpressureBuffer();
this.completionService.order(completionOrder, sink, userInfoService.getUserInfo());
Flux<ClientCompletionResponse> clientFlux = sink.asFlux();
// remove from queue when flux ended
clientFlux.doFinally(e -> cancelQueue(preCacheKey)).subscribe();
return EndpointSubscription.of(clientFlux, () -> cancelQueue(preCacheKey));
}
public class CompletionOrder {
private String completionUrl; // the backend decides which server/model to use based on the selected assistant
private Object payload; // the completion payload is built by the backend
}
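Each tool round grows the `messages` list inside the payload by two kinds of entries, which is what `onToolCallStop` builds with `Delta` objects. The sketch below shows the resulting shape as plain maps (roughly what Jackson would serialize), assuming the OpenAI-compatible chat format that vLLM's server accepts; `ToolRoundMessages` and `ExecutedCall` are hypothetical names.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

class ToolRoundMessages {
    // Hypothetical record: one tool call as requested by the model plus the tool's output.
    record ExecutedCall(String id, String name, String arguments, String result) {}

    static void appendToolRound(List<Map<String, Object>> messages, List<ExecutedCall> calls) {
        List<Map<String, Object>> toolCalls = new ArrayList<>();
        for (ExecutedCall c : calls) {
            Map<String, Object> function = new LinkedHashMap<>();
            function.put("name", c.name());
            function.put("arguments", c.arguments()); // JSON encoded as a string, as the model sent it
            Map<String, Object> toolCall = new LinkedHashMap<>();
            toolCall.put("id", c.id());
            toolCall.put("type", "function");
            toolCall.put("function", function);
            toolCalls.add(toolCall);
        }
        // 1) the assistant turn that requested the tools
        Map<String, Object> assistant = new LinkedHashMap<>();
        assistant.put("role", "assistant");
        assistant.put("tool_calls", toolCalls);
        messages.add(assistant);
        // 2) one "tool" message per call, linked to its request by tool_call_id
        for (ExecutedCall c : calls) {
            Map<String, Object> tool = new LinkedHashMap<>();
            tool.put("role", "tool");
            tool.put("tool_call_id", c.id());
            tool.put("content", c.result());
            messages.add(tool);
        }
    }
}
```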
The interface for Tools:
public interface ToolImpementation {
List<String> getToolNames();
ToolEntry createToolEntry(String name);
String buildToolResponse(ClientCompletionResponse currentClientResponse, ToolCall toolCall, UserInfo userInfo) throws Exception;
}
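Here is a hypothetical implementation plus a name-based registry, reduced to JDK types so it runs on its own: `SimpleTool` stands in for `ToolImpementation` (without `ToolEntry`, `ClientCompletionResponse` and `UserInfo`), and `EchoTool` and `ToolRegistry` are made-up names. As `getToolNames()` allows, one class can serve several function names.

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Simplified stand-in for ToolImpementation.
interface SimpleTool {
    List<String> getToolNames();

    String buildToolResponse(String toolName, String argumentsJson) throws Exception;
}

// Hypothetical tool serving two function names from one class.
class EchoTool implements SimpleTool {
    @Override
    public List<String> getToolNames() {
        return List.of("echo", "echo_reversed");
    }

    @Override
    public String buildToolResponse(String toolName, String argumentsJson) {
        return switch (toolName) {
            case "echo" -> argumentsJson;
            case "echo_reversed" -> new StringBuilder(argumentsJson).reverse().toString();
            default -> throw new IllegalArgumentException("unsupported tool " + toolName);
        };
    }
}

class ToolRegistry {
    private final Map<String, SimpleTool> byName = new HashMap<>();

    ToolRegistry(List<SimpleTool> tools) {
        for (SimpleTool tool : tools) {
            for (String name : tool.getToolNames()) {
                // fail fast instead of letting a later tool silently replace an earlier one
                if (byName.putIfAbsent(name, tool) != null) {
                    throw new IllegalStateException("duplicate tool name " + name);
                }
            }
        }
    }

    // Returns the tool output, or an error text the model can read and react to.
    String call(String name, String argumentsJson) {
        SimpleTool tool = byName.get(name);
        if (tool == null) {
            return "error: unknown tool " + name;
        }
        try {
            return tool.buildToolResponse(name, argumentsJson);
        } catch (Exception e) {
            return "error: " + e.getMessage();
        }
    }
}
```

Both choices differ from `registerTools()` and `onToolCallStop` in the post and are only suggestions: duplicate names fail at startup instead of overwriting each other in the map, and tool failures are returned as tool content for the model instead of terminating the client stream.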
The completion chunk DTO, which supports both non-streaming and streaming completions:
public class ChatCompletionChunkResponse {
private String id;
private String model;
private String object;
private List<Choice> choices;
// getter, setter
public static class Choice {
private Delta message;
private Delta delta;
@JsonProperty("finish_reason")
private String finishReason;
@JsonProperty("stop_reason")
private String stopReason;
// getter, setter
}
public static class Delta {
private String content;
@JsonProperty("reasoning_content")
private String reasoningContent;
private String role;
@JsonProperty("tool_calls")
private List<ToolCall> tool_calls;
private String tool_call_id;
// getter, setter
}
public static class ToolCall {
private String id;
private String type;
private String index;
private Function function;
// getter, setter
}
public static class Function {
private String name;
private String arguments;
// getter, setter
}
}
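Regarding the `[DONE]` workaround in `order()`: in streaming mode an OpenAI-compatible server such as vLLM sends Server-Sent Events where each `data:` line holds one chunk and the last one is the literal `data: [DONE]`, while a non-streaming response is a single JSON object. Instead of matching on a Jackson error message, the sentinel could be filtered out before deserialization. In WebFlux that might mean reading `ServerSentEvent<String>` via `ParameterizedTypeReference` and dropping `[DONE]` before mapping to `ChatCompletionChunkResponse`; the JDK-only sketch below (hypothetical `SseDataReader`) only shows the filtering on a raw body.

```java
import java.util.ArrayList;
import java.util.List;

class SseDataReader {
    /**
     * Extracts the JSON payloads from an OpenAI-style text/event-stream body.
     * Each event line looks like "data: {...}"; the stream ends with "data: [DONE]".
     */
    static List<String> jsonPayloads(String body) {
        List<String> payloads = new ArrayList<>();
        for (String line : body.split("\\R")) {
            if (!line.startsWith("data:")) {
                continue; // skip blank event separators and other SSE fields
            }
            String data = line.substring("data:".length()).strip();
            if (data.equals("[DONE]")) {
                break; // end-of-stream sentinel, not JSON
            }
            if (!data.isEmpty()) {
                payloads.add(data);
            }
        }
        return payloads;
    }
}
```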