// headless_lms_chatbot/chatbot_tools/
mod.rs1use std::collections::HashMap;
2
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5use sqlx::PgConnection;
6
7use crate::{
8 azure_chatbot::ChatbotUserContext,
9 chatbot_error::chatbot_err,
10 chatbot_tools::{
11 custom_tools::course_progress::CourseProgressTool,
12 provider_tools::azure_ai_search::AzureAISearchToolDefinition,
13 },
14 prelude::{BackendError, ChatbotError, ChatbotErrorType, ChatbotResult},
15};
16
/// Tools implemented in this crate (for example the course progress tool).
pub mod custom_tools;
/// Provider-side tool definitions (for example Azure AI Search).
pub mod provider_tools;
19
20pub trait ChatbotTool {
21 type State;
22 type Arguments: Serialize;
23
24 fn parse_arguments(args_string: String) -> ChatbotResult<Self::Arguments>;
26
27 fn from_db_and_arguments(
29 conn: &mut PgConnection,
30 arguments: Self::Arguments,
31 user_context: &ChatbotUserContext,
32 ) -> impl std::future::Future<Output = ChatbotResult<Self>> + Send
33 where
34 Self: Sized;
35
36 fn output(&self) -> String;
38
39 fn output_description_instructions(&self) -> Option<String>;
42
43 fn get_tool_output(&self) -> String {
45 let output = self.output();
46 let instructions = self.output_description_instructions();
47
48 if let Some(i) = instructions {
49 format!(
50 "Result: [output]{output}[/output]\n\nInstructions for describing the output: [instructions]{i}[/instructions]"
51 )
52 } else {
53 output
54 }
55 }
56
57 fn get_arguments(&self) -> &Self::Arguments;
59
60 fn get_tool_definition() -> AzureLLMFunctionToolDefinition;
63
64 fn new(
66 conn: &mut PgConnection,
67 args_string: String,
68 user_context: &ChatbotUserContext,
69 ) -> impl std::future::Future<Output = ChatbotResult<Self>> + Send
70 where
71 Self: Sized,
72 {
73 async {
74 let parsed = Self::parse_arguments(args_string)?;
75 Self::from_db_and_arguments(conn, parsed, user_context).await
76 }
77 }
78}
79
/// Pairs a tool's `state` with its `arguments`.
///
/// NOTE(review): presumably intended to hold a [`ChatbotTool`]'s `State` and
/// `Arguments` — confirm against the implementors in `custom_tools`.
///
/// The type definition deliberately carries no trait bounds (previously
/// `A: Serialize`). Bounds belong on the `impl` blocks that actually need them;
/// putting them on the struct forces every mention of the type to repeat them.
pub struct ToolProperties<S, A> {
    state: S,
    arguments: A,
}
84
/// A tool definition handed to the LLM provider: either one of this crate's
/// function tools or a provider-side Azure AI Search tool.
///
/// `#[serde(untagged)]` means each variant is serialized as just its inner
/// value, with no wrapper or discriminator. When deserializing, serde tries the
/// variants in declaration order and keeps the first that matches. Keep
/// `Function` before `Search` unless the two shapes are known not to overlap.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(untagged)]
pub enum AzureLLMToolDefinition {
    /// A function tool defined via [`ChatbotTool::get_tool_definition`].
    Function(AzureLLMFunctionToolDefinition),
    /// An Azure AI Search tool from `provider_tools::azure_ai_search`.
    Search(AzureAISearchToolDefinition),
}
91
/// Definition of a function tool in the flat shape this crate serializes.
///
/// Serialized field names are the Rust field names, except `tool_type`, which
/// is emitted as `type`. Field declaration order is also the serialized order.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct AzureLLMFunctionToolDefinition {
    /// Tool kind; serialized as `type`.
    #[serde(rename = "type")]
    pub tool_type: LLMToolType,
    /// Function name.
    /// NOTE(review): presumably this must match the `fn_name` that
    /// `get_chatbot_tool` dispatches on (e.g. `"course_progress"`) — confirm.
    pub name: String,
    /// Human-readable description of the tool.
    pub description: String,
    /// Schema of the arguments the tool accepts.
    pub parameters: LLMToolParams,
    /// NOTE(review): presumably the provider's strict-schema-adherence flag —
    /// confirm against the API this definition is sent to.
    pub strict: bool,
}
104
/// JSON-schema-style description of a function tool's parameters.
///
/// `rename_all = "camelCase"` makes `additional_properties` serialize as
/// `additionalProperties`; `tool_type` is explicitly renamed to `type`.
///
/// NOTE(review): `properties` is a `HashMap`, so its serialized key order is
/// unspecified and may differ between process runs.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct LLMToolParams {
    /// Schema type of the parameter block; serialized as `type`.
    #[serde(rename = "type")]
    pub tool_type: LLMToolParamType,
    /// Parameter name -> schema of that parameter.
    pub properties: HashMap<String, LLMToolParamProperties>,
    /// Names of required parameters (presumably keys of `properties`).
    pub required: Vec<String>,
    /// Whether parameters beyond `properties` are permitted; serialized as
    /// `additionalProperties`.
    pub additional_properties: bool,
}
116
/// Schema of a single tool parameter.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct LLMToolParamProperties {
    /// Parameter type name; serialized as `type`.
    /// Presumably a JSON Schema type such as `"string"` — confirm.
    #[serde(rename = "type")]
    pub param_type: String,
    /// Description of the parameter.
    pub description: String,
}
123
/// Schema type of a tool's parameter block.
///
/// The only variant, `Object`, serializes as `"object"` via
/// `rename_all = "snake_case"`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum LLMToolParamType {
    Object,
}
129
/// Kind of an LLM tool definition.
///
/// The only variant, `Function`, serializes as `"function"` via
/// `rename_all = "snake_case"`.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum LLMToolType {
    Function,
}
135
136pub fn get_chatbot_tool_definitions() -> Vec<AzureLLMToolDefinition> {
138 vec![AzureLLMToolDefinition::Function(
139 CourseProgressTool::get_tool_definition(),
140 )]
141}
142
143pub async fn get_chatbot_tool(
146 conn: &mut PgConnection,
147 fn_name: &str,
148 _fn_args: &Value, user_context: &ChatbotUserContext,
150) -> ChatbotResult<impl ChatbotTool> {
151 let tool = match fn_name {
152 "course_progress" => CourseProgressTool::new(conn, "".to_string(), user_context).await?,
153 _ => {
154 return Err(chatbot_err!(
155 InvalidToolName,
156 "Incorrect or unknown function name".to_string()
157 ));
158 }
159 };
160 Result::Ok(tool)
161}