Skip to main content

headless_lms_chatbot/
llm_utils.rs

1use secrecy::{ExposeSecret, SecretString};
2
3use crate::{
4    azure_chatbot::{
5        ChatResponse, InputItem, LLMRequest, LLMRequestParams, MistralParams, NonThinkingParams,
6        OutputItem, Reasoning, ReasoningOutput, SummaryType, ThinkingParams,
7    },
8    chatbot_error::ChatbotResult,
9    prelude::*,
10};
11use core::default::Default;
12use headless_lms_base::config::ApplicationConfiguration;
13use headless_lms_models::{
14    chatbot_configurations::{ChatbotConfiguration, ReasoningEffortLevel},
15    chatbot_configurations_models::{ChatbotConfigurationModel, ModelType},
16    chatbot_conversation_message_messages::{ChatbotConversationMessageMessage, MessageRole},
17    chatbot_conversation_message_reasoning::ChatbotConversationMessageReasoning,
18    chatbot_conversation_message_tool_calls::{ChatbotConversationMessageToolCall, ToolKind},
19    chatbot_conversation_message_tool_outputs::ChatbotConversationMessageToolOutput,
20    chatbot_conversation_messages::{ChatbotConversationMessage, Message},
21};
22use reqwest::Response;
23use reqwest::header::HeaderMap;
24use serde::{Deserialize, Serialize};
25use tracing::{debug, error, instrument, trace, warn};
26
27/// Common message structure used for LLM API requests
28#[derive(Serialize, Deserialize, Debug, Clone)]
29pub struct APIOutputMessage {
30    #[serde(flatten)]
31    pub message_type: OutputItem,
32}
33
34/// Common message structure used for LLM API requests
35#[derive(Serialize, Deserialize, Debug, Clone)]
36pub struct APIInputMessage {
37    #[serde(flatten)]
38    pub message_type: InputItem,
39}
40
41impl TryFrom<APIOutputMessage> for APIInputMessage {
42    type Error = ChatbotError;
43    fn try_from(message: APIOutputMessage) -> Result<Self, Self::Error> {
44        match message.message_type {
45            OutputItem::Message {
46                role,
47                content,
48                response_id: _response_id,
49            } => Ok(APIInputMessage {
50                message_type: InputItem::Message { role, content },
51            }),
52            OutputItem::FunctionCall {
53                call_id,
54                tool_name,
55                arguments,
56                response_id: _response_id,
57            } => Ok(APIInputMessage {
58                message_type: InputItem::FunctionCall {
59                    call_id,
60                    tool_name,
61                    arguments,
62                },
63            }),
64            OutputItem::FunctionCallOutput {
65                call_id,
66                output,
67                response_id: _response_id,
68            } => Ok(APIInputMessage {
69                message_type: InputItem::FunctionCallOutput { call_id, output },
70            }),
71            OutputItem::AzureAiSearchCall {
72                call_id,
73                arguments,
74                response_id: _response_id,
75            } => Ok(APIInputMessage {
76                message_type: InputItem::FunctionCall {
77                    call_id,
78                    tool_name: "azure_ai_search".to_string(),
79                    arguments,
80                },
81            }),
82            OutputItem::AzureAiSearchCallOutput {
83                call_id,
84                output,
85                response_id: _response_id,
86            } => Ok(APIInputMessage {
87                message_type: InputItem::FunctionCallOutput { call_id, output },
88            }),
89            OutputItem::Reasoning { .. } => {
90                Err(chatbot_err!(Other, "Reasoning input items not allowed."))
91            }
92        }
93    }
94}
95
96impl TryFrom<ChatbotConversationMessage> for APIInputMessage {
97    type Error = ChatbotError;
98
99    fn try_from(message: ChatbotConversationMessage) -> Result<Self, Self::Error> {
100        let res = match message.message {
101            Message::Text(text_message) => match text_message.message_role {
102                MessageRole::User | MessageRole::Assistant => APIInputMessage {
103                    message_type: InputItem::Message {
104                        role: text_message.message_role,
105                        content: MessageContent::Text(text_message.text),
106                    },
107                },
108                _ => {
109                    return Err(chatbot_err!(
110                        InvalidMessageShape,
111                        "A 'role: system' or 'role: developer' type text-variant ChatbotConversationMessage shouldn't be saved into the database."
112                    ));
113                }
114            },
115            Message::ToolCall(tool_call) => match tool_call.tool_kind {
116                ToolKind::Function => APIInputMessage {
117                    message_type: InputItem::FunctionCall {
118                        call_id: tool_call.tool_call_id,
119                        tool_name: tool_call.tool_name,
120                        arguments: serde_json::to_string(&tool_call.tool_arguments)?,
121                    },
122                },
123                ToolKind::AzureAiSearch => APIInputMessage {
124                    message_type: InputItem::FunctionCall {
125                        call_id: tool_call.tool_call_id,
126                        tool_name: "azure_ai_search".to_string(),
127                        arguments: serde_json::to_string(&tool_call.tool_arguments)?,
128                    },
129                },
130            },
131            Message::ToolOutput(tool_output) => match tool_output.tool_kind {
132                ToolKind::Function => APIInputMessage {
133                    message_type: InputItem::FunctionCallOutput {
134                        call_id: tool_output.tool_call_id,
135                        output: tool_output.output,
136                    },
137                },
138                ToolKind::AzureAiSearch => APIInputMessage {
139                    message_type: InputItem::FunctionCallOutput {
140                        call_id: tool_output.tool_call_id,
141                        output: tool_output.output,
142                    },
143                },
144            },
145            Message::Reasoning(..) => {
146                // todo: can reasoning input items be allowed? if there is a summary
147                return Err(chatbot_err!(Other, "Reasoning input items not allowed."));
148            }
149        };
150        Result::Ok(res)
151    }
152}
153
154#[derive(Serialize, Deserialize, Debug, Clone)]
155#[serde(untagged)]
156pub enum MessageContent {
157    Text(String),
158    Object(Vec<ChatResponse>),
159}
160
161impl MessageContent {
162    pub fn get_content_text(self) -> String {
163        match self {
164            MessageContent::Text(msg_text) => msg_text,
165            MessageContent::Object(output) => output
166                .iter()
167                .map(|x| x.text.to_owned())
168                .collect::<Vec<String>>()
169                .join(""),
170        }
171    }
172}
173
174impl APIOutputMessage {
175    /// Create a ChatbotConversationMessage from an APIMessage to save it into the DB.
176    /// Notice that the insert operation ignores some of the fields, like timestamps.
177    /// `to_chatbot_conversation_message` doesn't set the correct order_number field
178    /// value.
179    pub fn to_chatbot_conversation_message(
180        &self,
181        conversation_id: Uuid,
182    ) -> ChatbotResult<ChatbotConversationMessage> {
183        let res = match self.message_type.clone() {
184            OutputItem::Message {
185                role,
186                content,
187                response_id,
188            } => {
189                let text = content.get_content_text();
190                let used_tokens = estimate_tokens(&text);
191
192                ChatbotConversationMessage {
193                    conversation_id,
194                    message: Message::Text(ChatbotConversationMessageMessage {
195                        text,
196                        message_role: role,
197                        message_is_complete: true,
198                        used_tokens,
199                        response_id: if role == MessageRole::User {
200                            None
201                        } else {
202                            Some(response_id)
203                        },
204                        ..Default::default()
205                    }),
206                    ..Default::default()
207                }
208            }
209            OutputItem::FunctionCall {
210                call_id,
211                tool_name,
212                arguments,
213                response_id,
214            } => ChatbotConversationMessage {
215                conversation_id,
216                message: Message::ToolCall(ChatbotConversationMessageToolCall {
217                    tool_name,
218                    tool_arguments: serde_json::to_value(arguments)?,
219                    tool_call_id: call_id,
220                    tool_kind: ToolKind::Function,
221                    response_id,
222                    ..Default::default()
223                }),
224                ..Default::default()
225            },
226            OutputItem::FunctionCallOutput {
227                call_id,
228                output,
229                response_id,
230            } => ChatbotConversationMessage {
231                conversation_id,
232                message: Message::ToolOutput(ChatbotConversationMessageToolOutput {
233                    output,
234                    tool_call_id: call_id,
235                    tool_kind: ToolKind::Function,
236                    response_id,
237                    ..Default::default()
238                }),
239                ..Default::default()
240            },
241            OutputItem::AzureAiSearchCall {
242                call_id,
243                arguments,
244                response_id,
245            } => ChatbotConversationMessage {
246                conversation_id,
247                message: Message::ToolCall(ChatbotConversationMessageToolCall {
248                    tool_arguments: serde_json::to_value(arguments)?,
249                    tool_call_id: call_id,
250                    tool_kind: ToolKind::AzureAiSearch,
251                    tool_name: "azure_ai_search".to_string(),
252                    response_id,
253                    ..Default::default()
254                }),
255                ..Default::default()
256            },
257            OutputItem::AzureAiSearchCallOutput {
258                call_id,
259                output,
260                response_id,
261            } => ChatbotConversationMessage {
262                conversation_id,
263                message: Message::ToolOutput(ChatbotConversationMessageToolOutput {
264                    tool_call_id: call_id,
265                    tool_kind: ToolKind::AzureAiSearch,
266                    output,
267                    response_id,
268                    ..Default::default()
269                }),
270                ..Default::default()
271            },
272            OutputItem::Reasoning {
273                summary,
274                response_id,
275            } => {
276                let text = if !summary.is_empty() {
277                    Some(
278                        summary
279                            .iter()
280                            .map(|i| i.text.to_owned())
281                            .collect::<Vec<String>>()
282                            .join(" "),
283                    )
284                } else {
285                    None
286                };
287                ChatbotConversationMessage {
288                    conversation_id,
289                    message: Message::Reasoning(ChatbotConversationMessageReasoning {
290                        summary: text,
291                        response_id,
292                        ..Default::default()
293                    }),
294                    ..Default::default()
295                }
296            }
297        };
298        Result::Ok(res)
299    }
300}
301
302impl TryFrom<ChatbotConversationMessage> for APIOutputMessage {
303    type Error = ChatbotError;
304
305    fn try_from(message: ChatbotConversationMessage) -> ChatbotResult<Self> {
306        let res = match message.message {
307            Message::Text(text_message) => match text_message.message_role {
308                MessageRole::User | MessageRole::Assistant => APIOutputMessage {
309                    message_type: OutputItem::Message {
310                        role: text_message.message_role,
311                        content: MessageContent::Text(text_message.text),
312                        response_id: if text_message.message_role == MessageRole::User {
313                            "".to_string()
314                        } else {
315                            text_message.response_id.ok_or(chatbot_err!(
316                                    Other,
317                                    "Can't convert ChatbotConversationMessage into APIOutputMessage: a role='assistant' message should have a response_id, but it's missing"
318                                ))?
319                        },
320                    },
321                },
322                _ => {
323                    return Err(chatbot_err!(
324                        InvalidMessageShape,
325                        "A 'role: system' or 'role: developer' type text-variant ChatbotConversationMessage shouldn't be saved into the database."
326                    ));
327                }
328            },
329            Message::ToolCall(tool_call) => match tool_call.tool_kind {
330                ToolKind::Function => APIOutputMessage {
331                    message_type: OutputItem::FunctionCall {
332                        call_id: tool_call.tool_call_id,
333                        tool_name: tool_call.tool_name,
334                        arguments: serde_json::to_string(&tool_call.tool_arguments)?,
335                        response_id: tool_call.response_id,
336                    },
337                },
338                ToolKind::AzureAiSearch => APIOutputMessage {
339                    message_type: OutputItem::AzureAiSearchCall {
340                        call_id: tool_call.tool_call_id,
341                        arguments: serde_json::to_string(&tool_call.tool_arguments)?,
342                        response_id: tool_call.response_id,
343                    },
344                },
345            },
346            Message::ToolOutput(tool_output) => match tool_output.tool_kind {
347                ToolKind::Function => APIOutputMessage {
348                    message_type: OutputItem::FunctionCallOutput {
349                        call_id: tool_output.tool_call_id,
350                        output: tool_output.output,
351                        response_id: tool_output.response_id,
352                    },
353                },
354                ToolKind::AzureAiSearch => APIOutputMessage {
355                    message_type: OutputItem::AzureAiSearchCallOutput {
356                        call_id: tool_output.tool_call_id,
357                        output: tool_output.output,
358                        response_id: tool_output.response_id,
359                    },
360                },
361            },
362            Message::Reasoning(reasoning) => {
363                if let Some(text) = reasoning.summary {
364                    APIOutputMessage {
365                        message_type: OutputItem::Reasoning {
366                            summary: vec![ReasoningOutput {
367                                output_type: "summary_text".to_string(),
368                                text,
369                            }],
370                            response_id: reasoning.response_id,
371                        },
372                    }
373                } else {
374                    APIOutputMessage {
375                        message_type: OutputItem::Reasoning {
376                            summary: vec![],
377                            response_id: reasoning.response_id,
378                        },
379                    }
380                }
381            }
382        };
383        Result::Ok(res)
384    }
385}
386
387impl From<ChatbotConversationMessageToolOutput> for APIOutputMessage {
388    fn from(value: ChatbotConversationMessageToolOutput) -> Self {
389        match value.tool_kind {
390            ToolKind::Function => APIOutputMessage {
391                message_type: OutputItem::FunctionCallOutput {
392                    call_id: value.tool_call_id,
393                    output: value.output,
394                    response_id: value.response_id,
395                },
396            },
397            ToolKind::AzureAiSearch => APIOutputMessage {
398                message_type: OutputItem::AzureAiSearchCallOutput {
399                    response_id: value.response_id,
400                    call_id: value.tool_call_id,
401                    output: value.output,
402                },
403            },
404        }
405    }
406}
407
408impl TryFrom<APIOutputMessage> for ChatbotConversationMessageToolOutput {
409    type Error = ChatbotError;
410    fn try_from(value: APIOutputMessage) -> ChatbotResult<Self> {
411        match value.message_type {
412            OutputItem::FunctionCallOutput {
413                call_id,
414                output,
415                response_id,
416            } => Ok(ChatbotConversationMessageToolOutput {
417                output,
418                tool_call_id: call_id,
419                response_id,
420                ..Default::default()
421            }),
422            OutputItem::AzureAiSearchCallOutput {
423                response_id,
424                call_id,
425                output,
426            } => Ok(ChatbotConversationMessageToolOutput {
427                output,
428                tool_call_id: call_id,
429                response_id,
430                ..Default::default()
431            }),
432            _ => Err(chatbot_err!(
433                Other,
434                "Can't convert APIMessage to ChatbotConversationMessageToolOutput: APIMessage type is not OutputItem::FunctionCallOutput"
435            )),
436        }
437    }
438}
439
440/// An LLM tool call that is part of a request to Azure
441#[derive(Serialize, Deserialize, Debug, Clone)]
442pub struct APIToolCall {
443    pub function: APITool,
444    pub id: String,
445    #[serde(rename = "type")]
446    pub tool_type: ToolKind,
447}
448
449impl From<ChatbotConversationMessageToolCall> for APIToolCall {
450    fn from(value: ChatbotConversationMessageToolCall) -> Self {
451        APIToolCall {
452            function: APITool {
453                arguments: value.tool_arguments.to_string(),
454                name: value.tool_name,
455            },
456            id: value.tool_call_id,
457            tool_type: value.tool_kind,
458        }
459    }
460}
461
462impl TryFrom<APIToolCall> for ChatbotConversationMessageToolCall {
463    type Error = ChatbotError;
464    fn try_from(value: APIToolCall) -> ChatbotResult<Self> {
465        Ok(ChatbotConversationMessageToolCall {
466            tool_name: value.function.name,
467            tool_arguments: serde_json::from_str(&value.function.arguments)?,
468            tool_call_id: value.id,
469            tool_kind: value.tool_type,
470            ..Default::default()
471        })
472    }
473}
474
475#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
476pub struct APITool {
477    pub arguments: String,
478    pub name: String,
479}
480
481/// Simple completion-focused LLM request for Azure OpenAI
482/// Note: In Azure OpenAI, the model is specified in the URL, not in the request body
483#[derive(Serialize, Deserialize, Debug)]
484pub struct AzureCompletionRequest {
485    #[serde(flatten)]
486    pub base: LLMRequest,
487    pub stream: bool,
488}
489
490/// Response from LLM for simple completions
491#[derive(Deserialize, Debug)]
492pub struct LLMResponse {
493    pub id: String,
494    pub output: Vec<APIOutputMessage>,
495}
496
497/// Builds common headers for LLM requests
498#[instrument(skip(api_key), fields(api_key_length = api_key.expose_secret().len()))]
499pub fn build_llm_headers(api_key: &SecretString) -> anyhow::Result<HeaderMap> {
500    trace!("Building LLM request headers");
501    let mut headers = HeaderMap::new();
502    headers.insert(
503        "api-key",
504        // Exposed only here, at the point the header value is constructed.
505        api_key.expose_secret().parse().map_err(|_e| {
506            error!("Failed to parse API key");
507            anyhow::anyhow!("Invalid API key")
508        })?,
509    );
510    headers.insert(
511        "content-type",
512        "application/json".parse().map_err(|_e| {
513            error!("Failed to parse content-type header");
514            anyhow::anyhow!("Internal error")
515        })?,
516    );
517    trace!("Successfully built headers");
518    Ok(headers)
519}
520
521/// Estimate the number of tokens in a given text.
522#[instrument(skip(text), fields(text_length = text.len()))]
523pub fn estimate_tokens(text: &str) -> i32 {
524    trace!("Estimating tokens for text");
525    let text_length = text.chars().fold(0, |acc, c| {
526        let mut len = c.len_utf8() as i32;
527        if len > 1 {
528            // The longer the character is, the more likely the text around is taking up more tokens
529            len *= 2;
530        }
531        if c.is_ascii_punctuation() {
532            // Punctuation is less common and is thus less likely to be part of a token
533            len *= 2;
534        }
535        acc + len
536    });
537    // A token is roughly 4 characters
538    let estimated_tokens = text_length / 4;
539    trace!("Estimated {} tokens for text", estimated_tokens);
540    estimated_tokens
541}
542
543/// Makes a non-streaming request to an LLM
544#[instrument(skip(chat_request, endpoint, api_key), fields(
545    num_messages = chat_request.input.len(),
546    temperature,
547    max_tokens,
548    endpoint = %endpoint
549))]
550async fn make_llm_request(
551    chat_request: LLMRequest,
552    endpoint: &url::Url,
553    api_key: &SecretString,
554) -> anyhow::Result<LLMResponse> {
555    debug!(
556        "Preparing LLM request with {} messages",
557        chat_request.input.len()
558    );
559
560    trace!("Base request: {:?}", chat_request);
561
562    let request = AzureCompletionRequest {
563        base: chat_request,
564        stream: false,
565    };
566
567    let headers = build_llm_headers(api_key)?;
568    debug!("Sending request to LLM endpoint: {}", endpoint);
569
570    let response = REQWEST_CLIENT
571        .post(endpoint.clone())
572        .headers(headers)
573        .json(&request)
574        .send()
575        .await?;
576
577    trace!("Received response from LLM");
578    process_llm_response(response).await
579}
580
581/// Process a non-streaming LLM response
582#[instrument(skip(response), fields(status = %response.status()))]
583async fn process_llm_response(response: Response) -> anyhow::Result<LLMResponse> {
584    if !response.status().is_success() {
585        let status = response.status();
586        let error_text = response.text().await?;
587        error!(
588            status = %status,
589            error = %error_text,
590            "Error calling LLM API"
591        );
592        return Err(anyhow::anyhow!(
593            "Error calling LLM API: Status: {}. Error: {}",
594            status,
595            error_text
596        ));
597    }
598
599    trace!("Processing successful LLM response");
600    // Parse the response
601    let completion: LLMResponse = response.json().await?;
602    debug!(
603        "Successfully processed LLM response with {} choices",
604        completion.output.len()
605    );
606    Ok(completion)
607}
608
609/// Makes a streaming request to an LLM
610#[instrument(skip(chat_request, app_config), fields(
611    num_messages = chat_request.input.len(),
612    temperature,
613    max_tokens
614))]
615pub async fn make_streaming_llm_request(
616    chat_request: LLMRequest,
617    app_config: &ApplicationConfiguration,
618) -> anyhow::Result<Response> {
619    debug!(
620        "Preparing streaming LLM request with {} messages",
621        chat_request.input.len()
622    );
623    let azure_config = app_config.azure_configuration.as_ref().ok_or_else(|| {
624        error!("Azure configuration missing");
625        anyhow::anyhow!("Azure configuration is missing from the application configuration")
626    })?;
627
628    let chatbot_config = azure_config.chatbot_config.as_ref().ok_or_else(|| {
629        error!("Chatbot configuration missing");
630        anyhow::anyhow!("Chatbot configuration is missing from the Azure configuration")
631    })?;
632
633    let request = AzureCompletionRequest {
634        base: chat_request,
635        stream: true,
636    };
637
638    let headers = build_llm_headers(&chatbot_config.api_key)?;
639    let api_endpoint = chatbot_config.api_endpoint.to_owned();
640    debug!(
641        "Sending streaming request to LLM endpoint: {}",
642        api_endpoint
643    );
644
645    let response = REQWEST_CLIENT
646        .post(api_endpoint)
647        .headers(headers)
648        .json(&request)
649        .send()
650        .await?;
651
652    if !response.status().is_success() {
653        let status = response.status();
654        let error_text = response.text().await?;
655        error!(
656            status = %status,
657            error = %error_text,
658            "Error calling streaming LLM API"
659        );
660        return Err(anyhow::anyhow!(
661            "Error calling LLM API: Status: {}. Error: {}",
662            status,
663            error_text
664        ));
665    }
666
667    debug!("Successfully initiated streaming response");
668    Ok(response)
669}
670
671/// Makes a non-streaming request to an LLM using application configuration
672#[instrument(skip(chat_request, app_config), fields(
673    num_messages = chat_request.input.len(),
674    temperature,
675    max_tokens
676))]
677pub async fn make_blocking_llm_request(
678    chat_request: LLMRequest,
679    app_config: &ApplicationConfiguration,
680) -> anyhow::Result<LLMResponse> {
681    debug!(
682        "Preparing blocking LLM request with {} messages",
683        chat_request.input.len()
684    );
685    let azure_config = app_config.azure_configuration.as_ref().ok_or_else(|| {
686        error!("Azure configuration missing");
687        anyhow::anyhow!("Azure configuration is missing from the application configuration")
688    })?;
689
690    let chatbot_config = azure_config.chatbot_config.as_ref().ok_or_else(|| {
691        error!("Chatbot configuration missing");
692        anyhow::anyhow!("Chatbot configuration is missing from the Azure configuration")
693    })?;
694
695    let api_endpoint = chatbot_config.api_endpoint.to_owned();
696
697    trace!("Making LLM request to endpoint: {}", api_endpoint);
698    make_llm_request(chat_request, &api_endpoint, &chatbot_config.api_key).await
699}
700
701/// Collects all the completion choices to a string. Assumes the completion has only
702/// text message content, no tool calls or tool output.
703pub fn parse_text_completion(completion: LLMResponse) -> ChatbotResult<String> {
704    let res =
705    completion
706        .output
707        .into_iter()
708        .map(|x| match x.message_type {
709            OutputItem::Message {  content , ..} => Ok(content.get_content_text()),
710            OutputItem::Reasoning { .. } => Ok("".to_string()),
711            _ =>  Err(chatbot_err!( InvalidMessageShape, "It was assumed this LLM response contains only text, but a tool call or tool response was detected.")),
712        })
713        .collect::<ChatbotResult<Vec<String>>>()?
714        .join("");
715    if res.is_empty() {
716        return Err(chatbot_err!(
717            InvalidMessageShape,
718            "No content returned from LLM"
719        ));
720    };
721    Ok(res)
722}
723
724pub fn get_params_for_model(
725    model: &ChatbotConfigurationModel,
726    configuration: &ChatbotConfiguration,
727) -> LLMRequestParams {
728    if model.model.as_str() == "gpt-5.2-chat" {
729        return LLMRequestParams::GPTThinking(ThinkingParams {
730            reasoning: Some(Reasoning {
731                effort: ReasoningEffortLevel::Medium,
732                summary: Some(SummaryType::Detailed),
733            }),
734        });
735    }
736    match model.model_type {
737        ModelType::GPTNonThinking => LLMRequestParams::GPTNonThinking(NonThinkingParams {
738            temperature: Some(configuration.temperature),
739            top_p: Some(configuration.top_p),
740            frequency_penalty: Some(configuration.frequency_penalty),
741            presence_penalty: Some(configuration.presence_penalty),
742        }),
743        ModelType::GPTHardThinking => {
744            // make sure the effort value is valid for the model type
745            let effort = if configuration.reasoning_effort == ReasoningEffortLevel::Minimal {
746                ReasoningEffortLevel::Low
747            } else {
748                configuration.reasoning_effort
749            };
750            LLMRequestParams::GPTThinking(ThinkingParams {
751                reasoning: Some(Reasoning {
752                    effort,
753                    summary: Some(SummaryType::Detailed),
754                }),
755            })
756        }
757        ModelType::GPTThinking => {
758            // make sure the effort value is valid for the model type
759            let effort = if configuration.reasoning_effort == ReasoningEffortLevel::None {
760                ReasoningEffortLevel::Minimal
761            } else if configuration.reasoning_effort == ReasoningEffortLevel::Xhigh {
762                ReasoningEffortLevel::High
763            } else {
764                configuration.reasoning_effort
765            };
766            LLMRequestParams::GPTThinking(ThinkingParams {
767                reasoning: Some(Reasoning {
768                    effort,
769                    summary: Some(SummaryType::Detailed),
770                }),
771            })
772        }
773        ModelType::Mistral => LLMRequestParams::Mistral(MistralParams { test: true }),
774    }
775}
776
777/// Checks if the model_type is a thinking model type. This function defines
778/// which model types are thinking (reasoning)
779pub fn model_is_thinking(model_type: ModelType) -> bool {
780    matches!(
781        model_type,
782        ModelType::GPTHardThinking | ModelType::GPTThinking
783    )
784}
785
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_estimate_tokens() {
        // (input, expected estimate). Where given, the comment states the real number of tokens
        // for comparison with the heuristic.
        let cases: [(&str, i32); 8] = [
            // The real number is 4
            ("Hello, world!", 3),
            ("", 0),
            // The real number is 9
            ("This is a longer sentence with several words.", 11),
            // The real number is 7
            ("Hyvää päivää!", 7),
            // The real number is 9
            ("トークンは楽しい", 12),
            // The real number is 52
            ("🙂🙃😀😃😄😁😆😅😂🤣😊😇🙂🙃😀😃😄😁😆😅😂🤣😊😇", 48),
            // The real number is 18
            ("ฉันใช้โทเค็นทุกวัน", 27),
            // The real number is 17
            ("Жетони роблять мене щасливим", 25),
        ];
        for (input, expected) in cases {
            assert_eq!(estimate_tokens(input), expected, "input: {:?}", input);
        }
    }
}