// headless_lms_chatbot/chatbot_tools/mod.rs
use std::collections::HashMap;
2
3use serde::{Deserialize, Serialize};
4use sqlx::PgConnection;
5
6use crate::{
7 azure_chatbot::ChatbotUserContext,
8 chatbot_tools::course_progress::CourseProgressTool,
9 prelude::{BackendError, ChatbotError, ChatbotErrorType, ChatbotResult},
10};
11
12pub mod course_progress;
13
14pub trait ChatbotTool {
15 type State;
16 type Arguments: Serialize;
17
18 fn parse_arguments(args_string: String) -> ChatbotResult<Self::Arguments>;
20
21 fn from_db_and_arguments(
23 conn: &mut PgConnection,
24 arguments: Self::Arguments,
25 user_context: &ChatbotUserContext,
26 ) -> impl std::future::Future<Output = ChatbotResult<Self>> + Send
27 where
28 Self: Sized;
29
30 fn output(&self) -> String;
32
33 fn output_description_instructions(&self) -> Option<String>;
36
37 fn get_tool_output(&self) -> String {
39 let output = self.output();
40 let instructions = self.output_description_instructions();
41
42 if let Some(i) = instructions {
43 format!(
44 "Result: [output]{output}[/output]\n\nInstructions for describing the output: [instructions]{i}[/instructions]"
45 )
46 } else {
47 output
48 }
49 }
50
51 fn get_arguments(&self) -> &Self::Arguments;
53
54 fn get_tool_definition() -> AzureLLMToolDefinition;
57
58 fn new(
60 conn: &mut PgConnection,
61 args_string: String,
62 user_context: &ChatbotUserContext,
63 ) -> impl std::future::Future<Output = ChatbotResult<Self>> + Send
64 where
65 Self: Sized,
66 {
67 async {
68 let parsed = Self::parse_arguments(args_string)?;
69 Self::from_db_and_arguments(conn, parsed, user_context).await
70 }
71 }
72}
73
/// Pairs a tool's database-loaded state with its parsed arguments.
///
/// The `Serialize` bound previously on `A` has been removed: trait bounds on
/// struct definitions only force every mention of the type to repeat them —
/// bounds belong on the impl blocks that actually need them. Removing the
/// bound is backward-compatible (it only loosens requirements).
///
/// NOTE(review): neither field is read in this module — presumably consumed
/// by individual tool implementations; verify against `course_progress`.
pub struct ToolProperties<S, A> {
    /// State loaded from the database for this tool invocation.
    state: S,
    /// Parsed arguments the tool was invoked with.
    arguments: A,
}
78
/// Top-level tool definition in the shape expected by the Azure LLM API.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct AzureLLMToolDefinition {
    // Serialized under the JSON key `"type"` (`type` is a Rust keyword).
    #[serde(rename = "type")]
    pub tool_type: LLMToolType,
    /// The callable function this definition describes.
    pub function: LLMTool,
}
/// A single callable function exposed to the LLM.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct LLMTool {
    /// Name the LLM uses to refer to the tool in function calls.
    pub name: String,
    /// Description shown to the LLM of what the tool does.
    pub description: String,
    // Omitted entirely from the serialized JSON when `None`, rather than
    // emitting `"parameters": null`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parameters: Option<LLMToolParams>,
}
95
/// JSON-schema-style description of a tool's parameters.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct LLMToolParams {
    // Serialized under the JSON key `"type"` (`type` is a Rust keyword).
    #[serde(rename = "type")]
    pub tool_type: LLMToolParamType,
    /// Map from parameter name to its type and description.
    pub properties: HashMap<String, LLMToolParamProperties>,
    /// Names of the parameters the LLM must always supply.
    pub required: Vec<String>,
}
104
/// Type and description of a single tool parameter.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct LLMToolParamProperties {
    // Serialized under the JSON key `"type"`; a JSON-schema type name
    // (kept as a free-form string here rather than an enum).
    #[serde(rename = "type")]
    pub param_type: String,
    /// Description of the parameter shown to the LLM.
    pub description: String,
}
111
/// JSON-schema type of a tool's parameter container.
/// Currently only `"object"` is used; serialized in snake_case.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum LLMToolParamType {
    Object,
}
117
/// Kind of tool advertised to the LLM.
/// Currently only `"function"` is used; serialized in snake_case.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum LLMToolType {
    Function,
}
123
124pub fn get_chatbot_tool_definitions() -> Vec<AzureLLMToolDefinition> {
126 vec![CourseProgressTool::get_tool_definition()]
127}
128
129pub async fn get_chatbot_tool(
132 conn: &mut PgConnection,
133 fn_name: &str,
134 _fn_args: &str, user_context: &ChatbotUserContext,
136) -> ChatbotResult<impl ChatbotTool> {
137 let tool = match fn_name {
138 "course_progress" => CourseProgressTool::new(conn, "".to_string(), user_context).await?,
139 _ => {
140 return Err(ChatbotError::new(
141 ChatbotErrorType::InvalidToolName,
142 "Incorrect or unknown function name".to_string(),
143 None,
144 ));
145 }
146 };
147 Result::Ok(tool)
148}