Skip to main content

headless_lms_chatbot/chatbot_tools/
mod.rs

1use std::collections::HashMap;
2
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5use sqlx::PgConnection;
6
7use crate::{
8    azure_chatbot::ChatbotUserContext,
9    chatbot_error::chatbot_err,
10    chatbot_tools::{
11        custom_tools::course_progress::CourseProgressTool,
12        provider_tools::azure_ai_search::AzureAISearchToolDefinition,
13    },
14    prelude::{BackendError, ChatbotError, ChatbotErrorType, ChatbotResult},
15};
16
17pub mod custom_tools;
18pub mod provider_tools;
19
20pub trait ChatbotTool {
21    type State;
22    type Arguments: Serialize;
23
24    /// Parse the LLM-generated function arguments and clean them
25    fn parse_arguments(args_string: String) -> ChatbotResult<Self::Arguments>;
26
27    /// Create a new instance after parsing arguments
28    fn from_db_and_arguments(
29        conn: &mut PgConnection,
30        arguments: Self::Arguments,
31        user_context: &ChatbotUserContext,
32    ) -> impl std::future::Future<Output = ChatbotResult<Self>> + Send
33    where
34        Self: Sized;
35
36    /// Output the result of the tool call in LLM-readable form
37    fn output(&self) -> String;
38
39    /// Additional instructions for the LLM on how to describe and
40    /// communicate the tool output. Just-in-time prompt.
41    fn output_description_instructions(&self) -> Option<String>;
42
43    /// Get and format tool output and instructions for LLM
44    fn get_tool_output(&self) -> String {
45        let output = self.output();
46        let instructions = self.output_description_instructions();
47
48        if let Some(i) = instructions {
49            format!(
50                "Result: [output]{output}[/output]\n\nInstructions for describing the output: [instructions]{i}[/instructions]"
51            )
52        } else {
53            output
54        }
55    }
56
57    /// Get parsed arguments
58    fn get_arguments(&self) -> &Self::Arguments;
59
60    /// Get a AzureLLMToolDefinition struct that represents this tool.
61    /// The definition is sent to the LLM as part of a chat request.
62    fn get_tool_definition() -> AzureLLMFunctionToolDefinition;
63
64    /// Create a new instance from connection, args and context
65    fn new(
66        conn: &mut PgConnection,
67        args_string: String,
68        user_context: &ChatbotUserContext,
69    ) -> impl std::future::Future<Output = ChatbotResult<Self>> + Send
70    where
71        Self: Sized,
72    {
73        async {
74            let parsed = Self::parse_arguments(args_string)?;
75            Self::from_db_and_arguments(conn, parsed, user_context).await
76        }
77    }
78}
79
/// Pairs a tool's state (`S`) with its arguments (`A`).
///
/// NOTE(review): presumably instantiated as `ToolProperties<T::State, T::Arguments>`
/// for some `T: ChatbotTool` — confirm against implementors.
///
/// The struct definition deliberately carries no trait bounds. Any requirement
/// such as `A: Serialize` belongs on the `impl` blocks that actually need it, so
/// that merely naming or constructing the type does not impose it.
pub struct ToolProperties<S, A> {
    state: S,
    arguments: A,
}
84
85#[derive(Clone, Debug, Deserialize, Serialize)]
86#[serde(untagged)]
87pub enum AzureLLMToolDefinition {
88    Function(AzureLLMFunctionToolDefinition),
89    Search(AzureAISearchToolDefinition),
90}
91
92/// A tool definition that is formatted for Azure.
93/// Defines a tool (function) that the LLM can call.
94#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
95pub struct AzureLLMFunctionToolDefinition {
96    #[serde(rename = "type")]
97    pub tool_type: LLMToolType,
98    pub name: String,
99    pub description: String,
100    pub parameters: LLMToolParams,
101    /// Ensures that the LLM calls the tool with the correct params. Should be `true`
102    pub strict: bool,
103}
104
105/// Parameters that a chatbot tool accepts in an AzureLLMToolDefinition
106#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
107#[serde(rename_all = "camelCase")]
108pub struct LLMToolParams {
109    #[serde(rename = "type")]
110    pub tool_type: LLMToolParamType,
111    pub properties: HashMap<String, LLMToolParamProperties>,
112    pub required: Vec<String>,
113    /// required to be false
114    pub additional_properties: bool,
115}
116
117#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
118pub struct LLMToolParamProperties {
119    #[serde(rename = "type")]
120    pub param_type: String,
121    pub description: String,
122}
123
124#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
125#[serde(rename_all = "snake_case")]
126pub enum LLMToolParamType {
127    Object,
128}
129
130#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
131#[serde(rename_all = "snake_case")]
132pub enum LLMToolType {
133    Function,
134}
135
136/// Get a vec of AzureLLMToolDefinitions for all available chatbot tools
137pub fn get_chatbot_tool_definitions() -> Vec<AzureLLMToolDefinition> {
138    vec![AzureLLMToolDefinition::Function(
139        CourseProgressTool::get_tool_definition(),
140    )]
141}
142
143/// Create a chatbot tool with LLM-provided arguments by matching the tool call
144/// made by the LLM. User context and db connection are needed for some tools.
145pub async fn get_chatbot_tool(
146    conn: &mut PgConnection,
147    fn_name: &str,
148    _fn_args: &Value, // used in the future in other tool
149    user_context: &ChatbotUserContext,
150) -> ChatbotResult<impl ChatbotTool> {
151    let tool = match fn_name {
152        "course_progress" => CourseProgressTool::new(conn, "".to_string(), user_context).await?,
153        _ => {
154            return Err(chatbot_err!(
155                InvalidToolName,
156                "Incorrect or unknown function name".to_string()
157            ));
158        }
159    };
160    Result::Ok(tool)
161}