use crate::{
    azure_chatbot::{LLMRequest, ToolCallType},
    chatbot_error::ChatbotResult,
    prelude::*,
};
use core::default::Default;
use headless_lms_models::{
    application_task_default_language_models::TaskLMSpec,
    chatbot_conversation_message_tool_calls::ChatbotConversationMessageToolCall,
    chatbot_conversation_message_tool_outputs::ChatbotConversationMessageToolOutput,
    chatbot_conversation_messages::{ChatbotConversationMessage, MessageRole},
};
use headless_lms_utils::ApplicationConfiguration;
use reqwest::Response;
use reqwest::header::{HeaderMap, HeaderValue};
use serde::{Deserialize, Serialize};
use tracing::{debug, error, instrument, trace, warn};
18
/// Azure OpenAI REST API version, sent as the `api-version` query parameter
/// on every request (see `prepare_azure_endpoint`).
pub const LLM_API_VERSION: &str = "2024-10-21";
21
/// A single chat message in the LLM API wire format.
///
/// `fields` is flattened during (de)serialization, so the JSON shape is
/// `{ "role": ..., <content or tool fields at top level> }`.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct APIMessage {
    /// The author of the message (system/user/assistant/tool).
    pub role: MessageRole,
    /// The content variant; its fields are serialized at the top level.
    #[serde(flatten)]
    pub fields: APIMessageKind,
}
29
30impl APIMessage {
31 pub fn to_chatbot_conversation_message(
36 &self,
37 conversation_id: Uuid,
38 order_number: i32,
39 ) -> ChatbotResult<ChatbotConversationMessage> {
40 let res = match self.fields.clone() {
41 APIMessageKind::Text(msg) => ChatbotConversationMessage {
42 message_role: self.role,
43 conversation_id,
44 order_number,
45 message_is_complete: true,
46 used_tokens: estimate_tokens(&msg.content),
47 message: Some(msg.content),
48 ..Default::default()
49 },
50 APIMessageKind::ToolCall(msg) => {
51 let tool_call_fields = msg
52 .tool_calls
53 .iter()
54 .map(|x| ChatbotConversationMessageToolCall::try_from(x.to_owned()))
55 .collect::<ChatbotResult<Vec<_>>>()?;
56 let estimated_tokens: i32 = msg
57 .tool_calls
58 .iter()
59 .map(|x| estimate_tokens(&x.function.arguments))
60 .sum();
61 ChatbotConversationMessage {
62 message_role: self.role,
63 conversation_id,
64 order_number,
65 message_is_complete: true,
66 message: None,
67 tool_call_fields,
68 used_tokens: estimated_tokens,
69 ..Default::default()
70 }
71 }
72 APIMessageKind::ToolResponse(msg) => ChatbotConversationMessage {
73 message_role: self.role,
74 conversation_id,
75 order_number,
76 message_is_complete: true,
77 message: None,
78 used_tokens: 0,
79 tool_output: Some(ChatbotConversationMessageToolOutput::from(msg)),
80 ..Default::default()
81 },
82 };
83 Result::Ok(res)
84 }
85}
86
87impl TryFrom<ChatbotConversationMessage> for APIMessage {
88 type Error = ChatbotError;
89 fn try_from(message: ChatbotConversationMessage) -> ChatbotResult<Self> {
90 let res = match message.message_role {
91 MessageRole::Assistant => {
92 if !message.tool_call_fields.is_empty() {
93 APIMessage {
94 role: message.message_role,
95 fields: APIMessageKind::ToolCall(APIMessageToolCall {
96 tool_calls: message
97 .tool_call_fields
98 .iter()
99 .map(|f| APIToolCall::from(f.clone()))
100 .collect(),
101 }),
102 }
103 } else if let Some(msg) = message.message {
104 APIMessage {
105 role: message.message_role,
106 fields: APIMessageKind::Text(APIMessageText { content: msg }),
107 }
108 } else {
109 return Err(ChatbotError::new(
110 ChatbotErrorType::InvalidMessageShape,
111 "A 'role: assistant' type ChatbotConversationMessage must have either tool_call_fields or a text message.",
112 None,
113 ));
114 }
115 }
116 MessageRole::Tool => {
117 if let Some(tool_output) = message.tool_output {
118 APIMessage {
119 role: message.message_role,
120 fields: APIMessageKind::ToolResponse(APIMessageToolResponse {
121 tool_call_id: tool_output.tool_call_id,
122 name: tool_output.tool_name,
123 content: tool_output.tool_output,
124 }),
125 }
126 } else {
127 return Err(ChatbotError::new(
128 ChatbotErrorType::InvalidMessageShape,
129 "A 'role: tool' type ChatbotConversationMessage must have field tool_output.",
130 None,
131 ));
132 }
133 }
134 MessageRole::User => APIMessage {
135 role: message.message_role,
136 fields: APIMessageKind::Text(APIMessageText {
137 content: message.message.unwrap_or_default(),
138 }),
139 },
140 MessageRole::System => {
141 return Err(ChatbotError::new(
142 ChatbotErrorType::InvalidMessageShape,
143 "A 'role: system' type ChatbotConversationMessage cannot be saved into the database.",
144 None,
145 ));
146 }
147 };
148 Result::Ok(res)
149 }
150}
151
/// The content portion of an [`APIMessage`].
///
/// `#[serde(untagged)]`: there is no discriminator field in the JSON;
/// deserialization tries the variants in declaration order and picks the
/// first whose fields match.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(untagged)]
pub enum APIMessageKind {
    /// Plain text content.
    Text(APIMessageText),
    /// One or more tool (function) calls requested by the model.
    ToolCall(APIMessageToolCall),
    /// The output of a previously requested tool call.
    ToolResponse(APIMessageToolResponse),
}
159
/// Payload of a plain text message.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct APIMessageText {
    /// The message text.
    pub content: String,
}
165
/// Payload of an assistant message that requests tool calls.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct APIMessageToolCall {
    /// The tool calls the model wants executed.
    pub tool_calls: Vec<APIToolCall>,
}
172
/// Payload of a tool-role message carrying the result of a tool call.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct APIMessageToolResponse {
    /// Id of the tool call this response answers.
    pub tool_call_id: String,
    /// Name of the tool that produced the output.
    pub name: String,
    /// The tool's output, as a string.
    pub content: String,
}
180
181impl From<ChatbotConversationMessageToolOutput> for APIMessageToolResponse {
182 fn from(value: ChatbotConversationMessageToolOutput) -> Self {
183 APIMessageToolResponse {
184 tool_call_id: value.tool_call_id,
185 name: value.tool_name,
186 content: value.tool_output,
187 }
188 }
189}
190
191impl From<APIMessageToolResponse> for ChatbotConversationMessageToolOutput {
192 fn from(value: APIMessageToolResponse) -> Self {
193 ChatbotConversationMessageToolOutput {
194 tool_name: value.name,
195 tool_output: value.content,
196 tool_call_id: value.tool_call_id,
197 ..Default::default()
198 }
199 }
200}
201
/// A single tool call as represented in the LLM API.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct APIToolCall {
    /// The function to invoke, with its JSON-encoded arguments.
    pub function: APITool,
    /// Identifier that ties a later tool response back to this call.
    pub id: String,
    /// Serialized as `"type"` (`type` is a Rust keyword).
    #[serde(rename = "type")]
    pub tool_type: ToolCallType,
}
210
211impl From<ChatbotConversationMessageToolCall> for APIToolCall {
212 fn from(value: ChatbotConversationMessageToolCall) -> Self {
213 APIToolCall {
214 function: APITool {
215 arguments: value.tool_arguments.to_string(),
216 name: value.tool_name,
217 },
218 id: value.tool_call_id,
219 tool_type: ToolCallType::Function,
220 }
221 }
222}
223
224impl TryFrom<APIToolCall> for ChatbotConversationMessageToolCall {
225 type Error = ChatbotError;
226 fn try_from(value: APIToolCall) -> ChatbotResult<Self> {
227 Ok(ChatbotConversationMessageToolCall {
228 tool_name: value.function.name,
229 tool_arguments: serde_json::from_str(&value.function.arguments)?,
230 tool_call_id: value.id,
231 ..Default::default()
232 })
233 }
234}
235
/// The function half of a tool call.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct APITool {
    /// The function arguments as a JSON-encoded string (not a JSON object).
    pub arguments: String,
    /// The function name.
    pub name: String,
}
241
/// Request body for the Azure chat-completions endpoint: the common
/// [`LLMRequest`] fields plus the streaming flag.
#[derive(Serialize, Deserialize, Debug)]
pub struct AzureCompletionRequest {
    /// Shared request fields, flattened into the top level of the JSON body.
    #[serde(flatten)]
    pub base: LLMRequest,
    /// Whether the completion is streamed back in chunks (`true`) or
    /// returned in a single response (`false`).
    pub stream: bool,
}
250
/// Top level of a non-streaming completion response; only the parts this
/// module uses are deserialized.
#[derive(Deserialize, Debug)]
pub struct LLMCompletionResponse {
    /// The completion candidates returned by the model.
    pub choices: Vec<LLMChoice>,
}
256
/// One completion candidate in an [`LLMCompletionResponse`].
#[derive(Deserialize, Debug)]
pub struct LLMChoice {
    /// The generated message for this choice.
    pub message: APIMessage,
}
261
262#[instrument(skip(api_key), fields(api_key_length = api_key.len()))]
264pub fn build_llm_headers(api_key: &str) -> anyhow::Result<HeaderMap> {
265 trace!("Building LLM request headers");
266 let mut headers = HeaderMap::new();
267 headers.insert(
268 "api-key",
269 api_key.parse().map_err(|_e| {
270 error!("Failed to parse API key");
271 anyhow::anyhow!("Invalid API key")
272 })?,
273 );
274 headers.insert(
275 "content-type",
276 "application/json".parse().map_err(|_e| {
277 error!("Failed to parse content-type header");
278 anyhow::anyhow!("Internal error")
279 })?,
280 );
281 trace!("Successfully built headers");
282 Ok(headers)
283}
284
285#[instrument(skip(endpoint))]
287pub fn prepare_azure_endpoint(mut endpoint: url::Url) -> url::Url {
288 trace!(
289 "Preparing Azure endpoint with API version {}",
290 LLM_API_VERSION
291 );
292 endpoint.set_query(Some(&format!("api-version={}", LLM_API_VERSION)));
294 trace!("Endpoint prepared: {}", endpoint);
295 endpoint
296}
297
298#[instrument(skip(text), fields(text_length = text.len()))]
300pub fn estimate_tokens(text: &str) -> i32 {
301 trace!("Estimating tokens for text");
302 let text_length = text.chars().fold(0, |acc, c| {
303 let mut len = c.len_utf8() as i32;
304 if len > 1 {
305 len *= 2;
307 }
308 if c.is_ascii_punctuation() {
309 len *= 2;
311 }
312 acc + len
313 });
314 let estimated_tokens = text_length / 4;
316 trace!("Estimated {} tokens for text", estimated_tokens);
317 estimated_tokens
318}
319
320#[instrument(skip(chat_request, endpoint, api_key), fields(
322 num_messages = chat_request.messages.len(),
323 temperature,
324 max_tokens,
325 endpoint = %endpoint
326))]
327async fn make_llm_request(
328 chat_request: LLMRequest,
329 endpoint: &url::Url,
330 api_key: &str,
331) -> anyhow::Result<LLMCompletionResponse> {
332 debug!(
333 "Preparing LLM request with {} messages",
334 chat_request.messages.len()
335 );
336
337 trace!("Base request: {:?}", chat_request);
338
339 let request = AzureCompletionRequest {
340 base: chat_request,
341 stream: false,
342 };
343
344 let headers = build_llm_headers(api_key)?;
345 debug!("Sending request to LLM endpoint: {}", endpoint);
346
347 let response = REQWEST_CLIENT
348 .post(prepare_azure_endpoint(endpoint.clone()))
349 .headers(headers)
350 .json(&request)
351 .send()
352 .await?;
353
354 trace!("Received response from LLM");
355 process_llm_response(response).await
356}
357
358#[instrument(skip(response), fields(status = %response.status()))]
360async fn process_llm_response(response: Response) -> anyhow::Result<LLMCompletionResponse> {
361 if !response.status().is_success() {
362 let status = response.status();
363 let error_text = response.text().await?;
364 error!(
365 status = %status,
366 error = %error_text,
367 "Error calling LLM API"
368 );
369 return Err(anyhow::anyhow!(
370 "Error calling LLM API: Status: {}. Error: {}",
371 status,
372 error_text
373 ));
374 }
375
376 trace!("Processing successful LLM response");
377 let completion: LLMCompletionResponse = response.json().await?;
379 debug!(
380 "Successfully processed LLM response with {} choices",
381 completion.choices.len()
382 );
383 Ok(completion)
384}
385
386#[instrument(skip(chat_request, app_config), fields(
388 num_messages = chat_request.messages.len(),
389 temperature,
390 max_tokens
391))]
392pub async fn make_streaming_llm_request(
393 chat_request: LLMRequest,
394 model_deployment_name: &str,
395 app_config: &ApplicationConfiguration,
396) -> anyhow::Result<Response> {
397 debug!(
398 "Preparing streaming LLM request with {} messages",
399 chat_request.messages.len()
400 );
401 let azure_config = app_config.azure_configuration.as_ref().ok_or_else(|| {
402 error!("Azure configuration missing");
403 anyhow::anyhow!("Azure configuration is missing from the application configuration")
404 })?;
405
406 let chatbot_config = azure_config.chatbot_config.as_ref().ok_or_else(|| {
407 error!("Chatbot configuration missing");
408 anyhow::anyhow!("Chatbot configuration is missing from the Azure configuration")
409 })?;
410
411 trace!("Base request: {:?}", chat_request);
412
413 let request = AzureCompletionRequest {
414 base: chat_request,
415 stream: true,
416 };
417
418 let headers = build_llm_headers(&chatbot_config.api_key)?;
419 let api_endpoint = chatbot_config
420 .api_endpoint
421 .join(&(model_deployment_name.to_owned() + "/chat/completions"))?;
422 debug!(
423 "Sending streaming request to LLM endpoint: {}",
424 api_endpoint
425 );
426
427 let response = REQWEST_CLIENT
428 .post(prepare_azure_endpoint(api_endpoint.clone()))
429 .headers(headers)
430 .json(&request)
431 .send()
432 .await?;
433
434 if !response.status().is_success() {
435 let status = response.status();
436 let error_text = response.text().await?;
437 error!(
438 status = %status,
439 error = %error_text,
440 "Error calling streaming LLM API"
441 );
442 return Err(anyhow::anyhow!(
443 "Error calling LLM API: Status: {}. Error: {}",
444 status,
445 error_text
446 ));
447 }
448
449 debug!("Successfully initiated streaming response");
450 Ok(response)
451}
452
453#[instrument(skip(chat_request, app_config, task_lm), fields(
455 num_messages = chat_request.messages.len(),
456 temperature,
457 max_tokens
458))]
459pub async fn make_blocking_llm_request(
460 chat_request: LLMRequest,
461 app_config: &ApplicationConfiguration,
462 task_lm: &TaskLMSpec,
463) -> anyhow::Result<LLMCompletionResponse> {
464 debug!(
465 "Preparing blocking LLM request with {} messages",
466 chat_request.messages.len()
467 );
468 let azure_config = app_config.azure_configuration.as_ref().ok_or_else(|| {
469 error!("Azure configuration missing");
470 anyhow::anyhow!("Azure configuration is missing from the application configuration")
471 })?;
472
473 let chatbot_config = azure_config.chatbot_config.as_ref().ok_or_else(|| {
474 error!("Chatbot configuration missing");
475 anyhow::anyhow!("Chatbot configuration is missing from the Azure configuration")
476 })?;
477
478 let model = task_lm.deployment_name.to_owned();
479 let path = model + "/chat/completions";
480
481 let api_endpoint = chatbot_config.api_endpoint.join(&path)?;
482
483 trace!("Making LLM request to endpoint: {}", api_endpoint);
484 make_llm_request(chat_request, &api_endpoint, &chatbot_config.api_key).await
485}
486
487pub fn parse_text_completion(completion: LLMCompletionResponse) -> ChatbotResult<String> {
490 let res =
491 completion
492 .choices
493 .into_iter()
494 .map(|x| match x.message.fields {
495 APIMessageKind::Text(message) => Ok(message.content),
496 _ => Err(ChatbotError::new( ChatbotErrorType::InvalidMessageShape, "It was assumed this LLM response contains only text, but a tool call or tool response was detected.", None)),
497 })
498 .collect::<ChatbotResult<Vec<String>>>()?
499 .join("");
500 if res.is_empty() {
501 return Err(ChatbotError::new(
502 ChatbotErrorType::InvalidMessageShape,
503 "No content returned from LLM",
504 None,
505 ));
506 };
507 Ok(res)
508}
509
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_estimate_tokens() {
        // (input, expected estimate). Multi-byte characters and ASCII
        // punctuation are weighted more heavily than plain ASCII letters.
        let cases = [
            ("Hello, world!", 3),
            ("", 0),
            ("This is a longer sentence with several words.", 11),
            ("Hyvää päivää!", 7),
            ("トークンは楽しい", 12),
            ("🙂🙃😀😃😄😁😆😅😂🤣😊😇🙂🙃😀😃😄😁😆😅😂🤣😊😇", 48),
            ("ฉันใช้โทเค็นทุกวัน", 27),
            ("Жетони роблять мене щасливим", 25),
        ];
        for (input, expected) in cases {
            assert_eq!(estimate_tokens(input), expected, "input: {:?}", input);
        }
    }
}