headless_lms_chatbot/
azure_search_index.rs

1use crate::prelude::*;
2use headless_lms_utils::{ApplicationConfiguration, http::REQWEST_CLIENT};
3
/// Value sent as the `api-version` query parameter on every request built in this module.
const API_VERSION: &str = "2024-07-01";
5
6#[derive(Serialize, Deserialize)]
7#[serde(rename_all = "camelCase")]
8pub struct NewIndex {
9    name: String,
10    fields: Vec<Field>,
11    scoring_profiles: Vec<ScoringProfile>,
12    default_scoring_profile: Option<String>,
13    suggesters: Vec<Suggester>,
14    analyzers: Vec<Analyzer>,
15    tokenizers: Vec<serde_json::Value>,
16    token_filters: Vec<serde_json::Value>,
17    char_filters: Vec<serde_json::Value>,
18    cors_options: CorsOptions,
19    encryption_key: Option<EncryptionKey>,
20    similarity: Similarity,
21    semantic: Semantic,
22    vector_search: VectorSearch,
23}
24
25#[derive(Serialize, Deserialize)]
26#[serde(rename_all = "camelCase")]
27pub struct Analyzer {
28    name: String,
29    #[serde(rename = "@odata.type")]
30    odata_type: String,
31    char_filters: Vec<String>,
32    tokenizer: String,
33}
34
35#[derive(Serialize, Deserialize)]
36#[serde(rename_all = "camelCase")]
37pub struct CorsOptions {
38    allowed_origins: Vec<String>,
39    max_age_in_seconds: i64,
40}
41
42#[derive(Serialize, Deserialize)]
43#[serde(rename_all = "camelCase")]
44pub struct EncryptionKey {
45    key_vault_key_name: String,
46    key_vault_key_version: String,
47    key_vault_uri: String,
48    access_credentials: AccessCredentials,
49}
50
51#[derive(Serialize, Deserialize)]
52#[serde(rename_all = "camelCase")]
53pub struct AccessCredentials {
54    application_id: String,
55    application_secret: String,
56}
57
58#[derive(Serialize, Deserialize)]
59#[serde(rename_all = "camelCase")]
60pub struct Field {
61    name: String,
62    #[serde(rename = "type")]
63    field_type: String,
64    key: Option<bool>,
65    searchable: Option<bool>,
66    filterable: Option<bool>,
67    sortable: Option<bool>,
68    facetable: Option<bool>,
69    retrievable: Option<bool>,
70    index_analyzer: Option<String>,
71    search_analyzer: Option<String>,
72    analyzer: Option<String>,
73    synonym_maps: Option<Vec<String>>,
74    dimensions: Option<i64>,
75    vector_search_profile: Option<String>,
76    stored: Option<bool>,
77    vector_encoding: Option<serde_json::Value>,
78}
79
80#[derive(Serialize, Deserialize)]
81pub struct ScoringProfile {
82    name: String,
83    text: Text,
84    functions: Vec<Function>,
85}
86
87#[derive(Serialize, Deserialize)]
88#[serde(rename_all = "camelCase")]
89pub struct Function {
90    #[serde(rename = "type")]
91    function_type: String,
92    boost: i64,
93    field_name: String,
94    interpolation: String,
95    distance: Distance,
96}
97
98#[derive(Serialize, Deserialize)]
99#[serde(rename_all = "camelCase")]
100pub struct Distance {
101    reference_point_parameter: String,
102    boosting_distance: i64,
103}
104
105#[derive(Serialize, Deserialize)]
106pub struct Text {
107    weights: Weights,
108}
109
110#[derive(Serialize, Deserialize)]
111#[serde(rename_all = "camelCase")]
112pub struct Weights {
113    hotel_name: i64,
114}
115
116#[derive(Serialize, Deserialize)]
117#[serde(rename_all = "camelCase")]
118pub struct Semantic {
119    default_configuration: String,
120    configurations: Vec<SemanticConfiguration>,
121}
122
123#[derive(Serialize, Deserialize)]
124#[serde(rename_all = "camelCase")]
125pub struct SemanticConfiguration {
126    name: String,
127    prioritized_fields: SemanticConfigurationPrioritizedFields,
128}
129
130#[derive(Serialize, Deserialize)]
131#[serde(rename_all = "camelCase")]
132pub struct SemanticConfigurationPrioritizedFields {
133    title_field: FieldDescriptor,
134    prioritized_content_fields: Vec<FieldDescriptor>,
135    prioritized_keywords_fields: Vec<FieldDescriptor>,
136}
137
138#[derive(Serialize, Deserialize)]
139#[serde(rename_all = "camelCase")]
140pub struct FieldDescriptor {
141    field_name: String,
142}
143
144#[derive(Serialize, Deserialize)]
145pub struct Similarity {
146    #[serde(rename = "@odata.type")]
147    odata_type: String,
148    b: Option<f64>,
149    k1: Option<f64>,
150}
151
152#[derive(Serialize, Deserialize)]
153#[serde(rename_all = "camelCase")]
154pub struct Suggester {
155    name: String,
156    search_mode: String,
157    source_fields: Vec<String>,
158}
159
160#[derive(Serialize, Deserialize)]
161pub struct VectorSearch {
162    profiles: Vec<Profile>,
163    algorithms: Vec<Algorithm>,
164    compressions: Vec<Compression>,
165    vectorizers: Vec<Vectorizer>,
166}
167
168#[derive(Serialize, Deserialize)]
169#[serde(rename_all = "camelCase")]
170pub struct Vectorizer {
171    name: String,
172    kind: String,
173    #[serde(rename = "azureOpenAIParameters")]
174    azure_open_ai_parameters: AzureOpenAiParameters,
175    custom_web_api_parameters: Option<serde_json::Value>,
176}
177
178#[derive(Serialize, Deserialize)]
179#[serde(rename_all = "camelCase")]
180pub struct AzureOpenAiParameters {
181    resource_uri: String,
182    deployment_id: String,
183    api_key: String,
184    model_name: String,
185    auth_identity: Option<serde_json::Value>,
186}
187
188#[derive(Serialize, Deserialize)]
189#[serde(rename_all = "camelCase")]
190pub struct Algorithm {
191    name: String,
192    kind: String,
193    hnsw_parameters: Option<HnswParameters>,
194    exhaustive_knn_parameters: Option<ExhaustiveKnnParameters>,
195}
196
197#[derive(Serialize, Deserialize)]
198pub struct ExhaustiveKnnParameters {
199    metric: String,
200}
201
202#[derive(Serialize, Deserialize)]
203#[serde(rename_all = "camelCase")]
204pub struct HnswParameters {
205    m: i64,
206    metric: String,
207    ef_construction: i64,
208    ef_search: i64,
209}
210
211#[derive(Serialize, Deserialize)]
212#[serde(rename_all = "camelCase")]
213pub struct Compression {
214    name: String,
215    kind: String,
216    scalar_quantization_parameters: Option<ScalarQuantizationParameters>,
217    rerank_with_original_vectors: bool,
218    default_oversampling: i64,
219}
220
221#[derive(Serialize, Deserialize)]
222#[serde(rename_all = "camelCase")]
223pub struct ScalarQuantizationParameters {
224    quantized_data_type: String,
225}
226
227#[derive(Serialize, Deserialize)]
228pub struct Profile {
229    name: String,
230    algorithm: String,
231    compression: Option<String>,
232    vectorizer: Option<String>,
233}
234
235pub async fn does_search_index_exist(
236    index_name: &str,
237    app_config: &ApplicationConfiguration,
238) -> anyhow::Result<bool> {
239    let azure_config = app_config.azure_configuration.as_ref().ok_or_else(|| {
240        anyhow::anyhow!("Azure configuration is missing from the application configuration")
241    })?;
242
243    let search_config = azure_config.search_config.as_ref().ok_or_else(|| {
244        anyhow::anyhow!("Azure search configuration is missing from the Azure configuration")
245    })?;
246
247    let mut url = search_config.search_endpoint.clone();
248    url.set_path(&format!("indexes('{}')", index_name));
249    url.set_query(Some(&format!("api-version={}", API_VERSION)));
250
251    let response = REQWEST_CLIENT
252        .get(url)
253        .header("Content-Type", "application/json")
254        .header("api-key", search_config.search_api_key.clone())
255        .send()
256        .await?;
257
258    if response.status().is_success() {
259        Ok(true)
260    } else if response.status() == 404 {
261        Ok(false)
262    } else {
263        let status = response.status();
264        let error_text = response.text().await?;
265        Err(anyhow::anyhow!(
266            "Error checking if index exists. Status: {}. Error: {}",
267            status,
268            error_text
269        ))
270    }
271}
272
273pub async fn create_search_index(
274    index_name: String,
275    app_config: &ApplicationConfiguration,
276) -> anyhow::Result<()> {
277    let azure_config = app_config.azure_configuration.as_ref().ok_or_else(|| {
278        anyhow::anyhow!("Azure configuration is missing from the application configuration")
279    })?;
280
281    let search_config = azure_config.search_config.as_ref().ok_or_else(|| {
282        anyhow::anyhow!("Azure search configuration is missing from the Azure configuration")
283    })?;
284
285    let fields = vec![
286        Field {
287            name: "chunk_id".to_string(),
288            field_type: "Edm.String".to_string(),
289            key: Some(true),
290            searchable: Some(true),
291            filterable: Some(true),
292            retrievable: Some(true),
293            stored: Some(true),
294            sortable: Some(true),
295            facetable: Some(true),
296            analyzer: Some("keyword".to_string()),
297            index_analyzer: None,
298            search_analyzer: None,
299            synonym_maps: Some(vec![]),
300            dimensions: None,
301            vector_search_profile: None,
302            vector_encoding: None,
303        },
304        Field {
305            name: "language".to_string(),
306            field_type: "Edm.String".to_string(),
307            key: Some(false),
308            searchable: Some(true),
309            filterable: Some(true),
310            retrievable: Some(true),
311            stored: Some(true),
312            sortable: Some(false),
313            facetable: Some(false),
314            analyzer: None,
315            index_analyzer: None,
316            search_analyzer: None,
317            synonym_maps: Some(vec![]),
318            dimensions: None,
319            vector_search_profile: None,
320            vector_encoding: None,
321        },
322        Field {
323            name: "parent_id".to_string(),
324            field_type: "Edm.String".to_string(),
325            key: Some(false),
326            searchable: Some(true),
327            filterable: Some(true),
328            retrievable: Some(true),
329            stored: Some(true),
330            sortable: Some(true),
331            facetable: Some(true),
332            analyzer: None,
333            index_analyzer: None,
334            search_analyzer: None,
335            synonym_maps: Some(vec![]),
336            dimensions: None,
337            vector_search_profile: None,
338            vector_encoding: None,
339        },
340        Field {
341            name: "chunk".to_string(),
342            field_type: "Edm.String".to_string(),
343            key: Some(false),
344            searchable: Some(true),
345            filterable: Some(false),
346            retrievable: Some(true),
347            stored: Some(true),
348            sortable: Some(false),
349            facetable: Some(false),
350            analyzer: None,
351            index_analyzer: None,
352            search_analyzer: None,
353            synonym_maps: Some(vec![]),
354            dimensions: None,
355            vector_search_profile: None,
356            vector_encoding: None,
357        },
358        Field {
359            name: "title".to_string(),
360            field_type: "Edm.String".to_string(),
361            key: Some(false),
362            searchable: Some(true),
363            filterable: Some(true),
364            retrievable: Some(true),
365            stored: Some(true),
366            sortable: Some(false),
367            facetable: Some(false),
368            analyzer: None,
369            index_analyzer: None,
370            search_analyzer: None,
371            synonym_maps: Some(vec![]),
372            dimensions: None,
373            vector_search_profile: None,
374            vector_encoding: None,
375        },
376        Field {
377            name: "url".to_string(),
378            field_type: "Edm.String".to_string(),
379            key: Some(false),
380            searchable: Some(false),
381            filterable: Some(true),
382            retrievable: Some(true),
383            stored: Some(true),
384            sortable: Some(false),
385            facetable: Some(false),
386            analyzer: None,
387            index_analyzer: None,
388            search_analyzer: None,
389            synonym_maps: Some(vec![]),
390            dimensions: None,
391            vector_search_profile: None,
392            vector_encoding: None,
393        },
394        Field {
395            name: "course_id".to_string(),
396            field_type: "Edm.String".to_string(),
397            key: Some(false),
398            searchable: Some(false),
399            filterable: Some(true),
400            retrievable: Some(true),
401            stored: Some(true),
402            sortable: Some(false),
403            facetable: Some(false),
404            analyzer: None,
405            index_analyzer: None,
406            search_analyzer: None,
407            synonym_maps: Some(vec![]),
408            dimensions: None,
409            vector_search_profile: None,
410            vector_encoding: None,
411        },
412        Field {
413            name: "text_vector".to_string(),
414            field_type: "Collection(Edm.Single)".to_string(),
415            key: Some(false),
416            searchable: Some(true),
417            filterable: Some(false),
418            retrievable: Some(true),
419            stored: Some(true),
420            sortable: Some(false),
421            facetable: Some(false),
422            analyzer: None,
423            index_analyzer: None,
424            search_analyzer: None,
425            synonym_maps: Some(vec![]),
426            dimensions: Some(1536),
427            vector_search_profile: Some(format!("{}-azureOpenAi-text-profile", index_name)),
428            vector_encoding: None,
429        },
430    ];
431
432    let index = NewIndex {
433        name: index_name.clone(),
434        fields,
435        scoring_profiles: vec![],
436        default_scoring_profile: None,
437        suggesters: vec![],
438        analyzers: vec![],
439        tokenizers: vec![],
440        token_filters: vec![],
441        char_filters: vec![],
442        cors_options: CorsOptions {
443            allowed_origins: vec!["*".to_string()],
444            max_age_in_seconds: 300,
445        },
446        encryption_key: None,
447        similarity: Similarity {
448            odata_type: "#Microsoft.Azure.Search.BM25Similarity".to_string(),
449            b: None,
450            k1: None,
451        },
452        semantic: Semantic {
453            default_configuration: format!("{}-semantic-configuration", index_name),
454            configurations: vec![SemanticConfiguration {
455                name: format!("{}-semantic-configuration", index_name),
456                prioritized_fields: SemanticConfigurationPrioritizedFields {
457                    title_field: FieldDescriptor {
458                        field_name: "title".to_string(),
459                    },
460                    prioritized_content_fields: vec![FieldDescriptor {
461                        field_name: "chunk".to_string(),
462                    }],
463                    prioritized_keywords_fields: vec![],
464                },
465            }],
466        },
467        vector_search: VectorSearch {
468            profiles: vec![Profile {
469                name: format!("{}-azureOpenAi-text-profile", index_name),
470                algorithm: format!("{}-algorithm", index_name),
471                vectorizer: Some(format!("{}-azureOpenAi-text-vectorizer", index_name)),
472                compression: None,
473            }],
474            algorithms: vec![Algorithm {
475                name: format!("{}-algorithm", index_name),
476                kind: "hnsw".to_string(),
477                hnsw_parameters: Some(HnswParameters {
478                    m: 4,
479                    metric: "cosine".to_string(),
480                    ef_construction: 400,
481                    ef_search: 500,
482                }),
483                exhaustive_knn_parameters: None,
484            }],
485            vectorizers: vec![Vectorizer {
486                name: format!("{}-azureOpenAi-text-vectorizer", index_name),
487                kind: "azureOpenAI".to_string(),
488                azure_open_ai_parameters: AzureOpenAiParameters {
489                    resource_uri: search_config.vectorizer_resource_uri.clone(),
490                    deployment_id: search_config.vectorizer_deployment_id.clone(),
491                    api_key: search_config.vectorizer_api_key.clone(),
492                    model_name: search_config.vectorizer_model_name.clone(),
493                    auth_identity: None,
494                },
495                custom_web_api_parameters: None,
496            }],
497            compressions: vec![],
498        },
499    };
500
501    let index_json = serde_json::to_string(&index)?;
502
503    let mut url = search_config.search_endpoint.clone();
504    url.set_path("/indexes");
505    url.set_query(Some(&format!("api-version={}", API_VERSION)));
506
507    let response = REQWEST_CLIENT
508        .post(url)
509        .header("Content-Type", "application/json")
510        .header("api-key", search_config.search_api_key.clone())
511        .body(index_json)
512        .send()
513        .await?;
514
515    // Check for a successful response
516    if response.status().is_success() {
517        println!("Index created successfully: {}", index_name);
518        Ok(())
519    } else {
520        let status = response.status();
521        let error_text = response.text().await?;
522        Err(anyhow::anyhow!(
523            "Failed to create index. Status: {}. Error: {}",
524            status,
525            error_text
526        ))
527    }
528}
529
530#[derive(Serialize, Deserialize)]
531pub struct IndexAction<T> {
532    #[serde(rename = "@search.action")]
533    pub search_action: String,
534    pub document: T,
535}
536
537#[derive(Serialize, Deserialize)]
538pub struct IndexBatch<T> {
539    pub value: Vec<IndexAction<T>>,
540}
541
542pub async fn add_documents_to_index<T>(
543    index_name: &str,
544    documents: Vec<T>,
545    app_config: &ApplicationConfiguration,
546) -> anyhow::Result<()>
547where
548    T: Serialize,
549{
550    let azure_config = app_config.azure_configuration.as_ref().ok_or_else(|| {
551        anyhow::anyhow!("Azure configuration is missing from the application configuration")
552    })?;
553
554    let search_config = azure_config.search_config.as_ref().ok_or_else(|| {
555        anyhow::anyhow!("Azure search configuration is missing from the Azure configuration")
556    })?;
557
558    let mut url = search_config.search_endpoint.clone();
559    url.set_path(&format!("indexes('{}')/docs/index", index_name));
560    url.set_query(Some(&format!("api-version={}", API_VERSION)));
561
562    let index_actions: Vec<IndexAction<String>> = documents
563        .into_iter()
564        .map(|doc| IndexAction {
565            search_action: "upload".to_string(),
566            document: serde_json::to_string(&doc).unwrap(),
567        })
568        .collect();
569
570    let batch = IndexBatch {
571        value: index_actions,
572    };
573
574    let batch_json = serde_json::to_string(&batch)?;
575
576    let response = REQWEST_CLIENT
577        .post(url)
578        .header("Content-Type", "application/json")
579        .header("api-key", search_config.search_api_key.clone())
580        .body(batch_json)
581        .send()
582        .await?;
583
584    if response.status().is_success() {
585        println!("Documents added successfully to index: {}", index_name);
586        Ok(())
587    } else {
588        let status = response.status();
589        let error_text = response.text().await?;
590        Err(anyhow::anyhow!(
591            "Failed to add documents to index. Status: {}. Error: {}",
592            status,
593            error_text
594        ))
595    }
596}