headless_lms_chatbot/chatbot_tools/
mod.rs

1use std::collections::HashMap;
2
3use serde::{Deserialize, Serialize};
4use sqlx::PgConnection;
5
6use crate::{
7    azure_chatbot::ChatbotUserContext,
8    chatbot_tools::course_progress::CourseProgressTool,
9    prelude::{BackendError, ChatbotError, ChatbotErrorType, ChatbotResult},
10};
11
12pub mod course_progress;
13
14pub trait ChatbotTool {
15    type State;
16    type Arguments: Serialize;
17
18    /// Parse the LLM-generated function arguments and clean them
19    fn parse_arguments(args_string: String) -> ChatbotResult<Self::Arguments>;
20
21    /// Create a new instance after parsing arguments
22    fn from_db_and_arguments(
23        conn: &mut PgConnection,
24        arguments: Self::Arguments,
25        user_context: &ChatbotUserContext,
26    ) -> impl std::future::Future<Output = ChatbotResult<Self>> + Send
27    where
28        Self: Sized;
29
30    /// Output the result of the tool call in LLM-readable form
31    fn output(&self) -> String;
32
33    /// Additional instructions for the LLM on how to describe and
34    /// communicate the tool output. Just-in-time prompt.
35    fn output_description_instructions(&self) -> Option<String>;
36
37    /// Get and format tool output and instructions for LLM
38    fn get_tool_output(&self) -> String {
39        let output = self.output();
40        let instructions = self.output_description_instructions();
41
42        if let Some(i) = instructions {
43            format!(
44                "Result: [output]{output}[/output]\n\nInstructions for describing the output: [instructions]{i}[/instructions]"
45            )
46        } else {
47            output
48        }
49    }
50
51    /// Get parsed arguments
52    fn get_arguments(&self) -> &Self::Arguments;
53
54    /// Get a AzureLLMToolDefinition struct that represents this tool.
55    /// The definition is sent to the LLM as part of a chat request.
56    fn get_tool_definition() -> AzureLLMToolDefinition;
57
58    /// Create a new instance from connection, args and context
59    fn new(
60        conn: &mut PgConnection,
61        args_string: String,
62        user_context: &ChatbotUserContext,
63    ) -> impl std::future::Future<Output = ChatbotResult<Self>> + Send
64    where
65        Self: Sized,
66    {
67        async {
68            let parsed = Self::parse_arguments(args_string)?;
69            Self::from_db_and_arguments(conn, parsed, user_context).await
70        }
71    }
72}
73
/// Pairs a tool's state with the arguments it was created with.
///
/// The type parameters are deliberately unbounded: trait bounds such as
/// `A: Serialize` belong on the `impl` blocks that need them, not on the
/// data structure itself.
/// NOTE(review): presumably mirrors `ChatbotTool::State` / `ChatbotTool::Arguments`
/// of an implementor; confirm against the tool modules.
pub struct ToolProperties<S, A> {
    /// Tool-specific state.
    state: S,
    /// Parsed arguments of the tool call.
    arguments: A,
}
78
79/// A tool definition that is formatted for Azure.
80/// Defines a tool (function) that the LLM can call.
81#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
82pub struct AzureLLMToolDefinition {
83    #[serde(rename = "type")]
84    pub tool_type: LLMToolType,
85    pub function: LLMTool,
86}
87/// Content of an AzureLLMToolDefinition
88#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
89pub struct LLMTool {
90    pub name: String,
91    pub description: String,
92    #[serde(skip_serializing_if = "Option::is_none")]
93    pub parameters: Option<LLMToolParams>,
94}
95
96/// Parameters that a chatbot tool accepts in an AzureLLMToolDefinition
97#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
98pub struct LLMToolParams {
99    #[serde(rename = "type")]
100    pub tool_type: LLMToolParamType,
101    pub properties: HashMap<String, LLMToolParamProperties>,
102    pub required: Vec<String>,
103}
104
105#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
106pub struct LLMToolParamProperties {
107    #[serde(rename = "type")]
108    pub param_type: String,
109    pub description: String,
110}
111
112#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
113#[serde(rename_all = "snake_case")]
114pub enum LLMToolParamType {
115    Object,
116}
117
118#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
119#[serde(rename_all = "snake_case")]
120pub enum LLMToolType {
121    Function,
122}
123
124/// Get a vec of AzureLLMToolDefinitions for all available chatbot tools
125pub fn get_chatbot_tool_definitions() -> Vec<AzureLLMToolDefinition> {
126    vec![CourseProgressTool::get_tool_definition()]
127}
128
129/// Create a chatbot tool with LLM-provided arguments by matching the tool call
130/// made by the LLM. User context and db connection are needed for some tools.
131pub async fn get_chatbot_tool(
132    conn: &mut PgConnection,
133    fn_name: &str,
134    _fn_args: &str, // used in the future in other tool
135    user_context: &ChatbotUserContext,
136) -> ChatbotResult<impl ChatbotTool> {
137    let tool = match fn_name {
138        "course_progress" => CourseProgressTool::new(conn, "".to_string(), user_context).await?,
139        _ => {
140            return Err(ChatbotError::new(
141                ChatbotErrorType::InvalidToolName,
142                "Incorrect or unknown function name".to_string(),
143                None,
144            ));
145        }
146    };
147    Result::Ok(tool)
148}