icu_segmenter/provider/
lstm.rs

1// This file is part of ICU4X. For terms of use, please see the file
2// called LICENSE at the top level of the ICU4X source tree
3// (online at: https://github.com/unicode-org/icu4x/blob/main/LICENSE ).
4
//! Data provider struct definitions for the LSTM segmentation model.
6
7// Provider structs must be stable
8#![allow(clippy::exhaustive_structs, clippy::exhaustive_enums)]
9
10use icu_provider::prelude::*;
11use potential_utf::PotentialUtf8;
12use zerovec::{ZeroMap, ZeroVec};
13
14// We do this instead of const generics because ZeroFrom and Yokeable derives, as well as serde
15// don't support them
16macro_rules! lstm_matrix {
17    ($name:ident, $generic:literal) => {
18        /// The struct that stores a LSTM's matrix.
19        ///
20        /// <div class="stab unstable">
21        /// 🚧 This code is considered unstable; it may change at any time, in breaking or non-breaking ways,
22        /// including in SemVer minor releases. While the serde representation of data structs is guaranteed
23        /// to be stable, their Rust representation might not be. Use with caution.
24        /// </div>
25        #[derive(PartialEq, Debug, Clone, zerofrom::ZeroFrom, yoke::Yokeable)]
26        #[cfg_attr(feature = "datagen", derive(serde::Serialize))]
27        pub struct $name<'data> {
28            // Invariant: dims.product() == data.len()
29            #[allow(missing_docs)]
30            pub(crate) dims: [u16; $generic],
31            #[allow(missing_docs)]
32            pub(crate) data: ZeroVec<'data, f32>,
33        }
34
35        impl<'data> $name<'data> {
36            #[cfg(any(feature = "serde", feature = "datagen"))]
37            /// Creates a LstmMatrix with the given dimensions. Fails if the dimensions don't match the data.
38            pub fn from_parts(
39                dims: [u16; $generic],
40                data: ZeroVec<'data, f32>,
41            ) -> Result<Self, DataError> {
42                if dims.iter().map(|&i| i as usize).product::<usize>() != data.len() {
43                    Err(DataError::custom("Dimension mismatch"))
44                } else {
45                    Ok(Self { dims, data })
46                }
47            }
48
49            #[doc(hidden)] // databake
50            pub const fn from_parts_unchecked(
51                dims: [u16; $generic],
52                data: ZeroVec<'data, f32>,
53            ) -> Self {
54                Self { dims, data }
55            }
56        }
57
58        #[cfg(feature = "serde")]
59        impl<'de: 'data, 'data> serde::Deserialize<'de> for $name<'data> {
60            fn deserialize<S>(deserializer: S) -> Result<Self, S::Error>
61            where
62                S: serde::de::Deserializer<'de>,
63            {
64                #[derive(serde::Deserialize)]
65                struct Raw<'data> {
66                    dims: [u16; $generic],
67                    #[serde(borrow)]
68                    data: ZeroVec<'data, f32>,
69                }
70
71                let raw = Raw::deserialize(deserializer)?;
72
73                use serde::de::Error;
74                Self::from_parts(raw.dims, raw.data)
75                    .map_err(|_| S::Error::custom("Dimension mismatch"))
76            }
77        }
78
79        #[cfg(feature = "datagen")]
80        impl databake::Bake for $name<'_> {
81            fn bake(&self, env: &databake::CrateEnv) -> databake::TokenStream {
82                let dims = self.dims.bake(env);
83                let data = self.data.bake(env);
84                databake::quote! {
85                    icu_segmenter::provider::$name::from_parts_unchecked(#dims, #data)
86                }
87            }
88        }
89
90        #[cfg(feature = "datagen")]
91        impl databake::BakeSize for $name<'_> {
92            fn borrows_size(&self) -> usize {
93                self.data.borrows_size()
94            }
95        }
96    };
97}
98
99lstm_matrix!(LstmMatrix1, 1);
100lstm_matrix!(LstmMatrix2, 2);
101lstm_matrix!(LstmMatrix3, 3);
102
/// Selects the kind of text unit an LSTM model works on.
///
/// <div class="stab unstable">
/// 🚧 This code is considered unstable; it may change at any time, in breaking or non-breaking ways,
/// including in SemVer minor releases. While the serde representation of data structs is guaranteed
/// to be stable, their Rust representation might not be. Use with caution.
/// </div>
#[derive(PartialEq, Debug, Clone, Copy)]
#[cfg_attr(
    feature = "datagen",
    derive(serde::Serialize, databake::Bake),
    databake(path = icu_segmenter::provider)
)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize))]
pub enum ModelType {
    /// The model works on code points.
    Codepoints,
    /// The model works on grapheme clusters.
    GraphemeClusters,
}
120
/// The struct that stores an LSTM model.
///
/// In the field documentation below, `e` is `embedding.dims[1]` and `h` is `fw_u.dims[2]`;
/// the listed dims are the ones `try_from_parts` requires.
///
/// <div class="stab unstable">
/// 🚧 This code is considered unstable; it may change at any time, in breaking or non-breaking ways,
/// including in SemVer minor releases. While the serde representation of data structs is guaranteed
/// to be stable, their Rust representation might not be. Use with caution.
/// </div>
#[derive(PartialEq, Debug, Clone, yoke::Yokeable, zerofrom::ZeroFrom)]
#[cfg_attr(feature = "datagen", derive(serde::Serialize))]
#[yoke(prove_covariance_manually)]
pub struct LstmDataFloat32<'data> {
    /// Type of the model
    pub(crate) model: ModelType,
    /// The grapheme cluster dictionary used to train the model
    pub(crate) dic: ZeroMap<'data, PotentialUtf8, u16>,
    /// The embedding layer. Dims `[dic.len() + 1, e]`
    pub(crate) embedding: LstmMatrix2<'data>,
    /// The forward layer's first matrix. Dims `[4, h, e]`
    pub(crate) fw_w: LstmMatrix3<'data>,
    /// The forward layer's second matrix. Dims `[4, h, h]`
    pub(crate) fw_u: LstmMatrix3<'data>,
    /// The forward layer's bias. Dims `[4, h]`
    pub(crate) fw_b: LstmMatrix2<'data>,
    /// The backward layer's first matrix. Dims `[4, h, e]`
    pub(crate) bw_w: LstmMatrix3<'data>,
    /// The backward layer's second matrix. Dims `[4, h, h]`
    pub(crate) bw_u: LstmMatrix3<'data>,
    /// The backward layer's bias. Dims `[4, h]`
    pub(crate) bw_b: LstmMatrix2<'data>,
    /// The output layer's weights. Dims `[2, 4, h]`
    pub(crate) time_w: LstmMatrix3<'data>,
    /// The output layer's bias. Dims `[4]`
    pub(crate) time_b: LstmMatrix1<'data>,
}
155
156impl<'data> LstmDataFloat32<'data> {
157    #[doc(hidden)] // databake
158    #[allow(clippy::too_many_arguments)] // constructor
159    pub const fn from_parts_unchecked(
160        model: ModelType,
161        dic: ZeroMap<'data, PotentialUtf8, u16>,
162        embedding: LstmMatrix2<'data>,
163        fw_w: LstmMatrix3<'data>,
164        fw_u: LstmMatrix3<'data>,
165        fw_b: LstmMatrix2<'data>,
166        bw_w: LstmMatrix3<'data>,
167        bw_u: LstmMatrix3<'data>,
168        bw_b: LstmMatrix2<'data>,
169        time_w: LstmMatrix3<'data>,
170        time_b: LstmMatrix1<'data>,
171    ) -> Self {
172        Self {
173            model,
174            dic,
175            embedding,
176            fw_w,
177            fw_u,
178            fw_b,
179            bw_w,
180            bw_u,
181            bw_b,
182            time_w,
183            time_b,
184        }
185    }
186
187    #[cfg(any(feature = "serde", feature = "datagen"))]
188    /// Creates a LstmDataFloat32 with the given data. Fails if the matrix dimensions are inconsistent.
189    #[allow(clippy::too_many_arguments)] // constructor
190    pub fn try_from_parts(
191        model: ModelType,
192        dic: ZeroMap<'data, PotentialUtf8, u16>,
193        embedding: LstmMatrix2<'data>,
194        fw_w: LstmMatrix3<'data>,
195        fw_u: LstmMatrix3<'data>,
196        fw_b: LstmMatrix2<'data>,
197        bw_w: LstmMatrix3<'data>,
198        bw_u: LstmMatrix3<'data>,
199        bw_b: LstmMatrix2<'data>,
200        time_w: LstmMatrix3<'data>,
201        time_b: LstmMatrix1<'data>,
202    ) -> Result<Self, DataError> {
203        let dic_len = u16::try_from(dic.len())
204            .map_err(|_| DataError::custom("Dictionary does not fit in u16"))?;
205
206        let num_classes = embedding.dims[0];
207        let embedd_dim = embedding.dims[1];
208        let hunits = fw_u.dims[2];
209        if num_classes - 1 != dic_len
210            || fw_w.dims != [4, hunits, embedd_dim]
211            || fw_u.dims != [4, hunits, hunits]
212            || fw_b.dims != [4, hunits]
213            || bw_w.dims != [4, hunits, embedd_dim]
214            || bw_u.dims != [4, hunits, hunits]
215            || bw_b.dims != [4, hunits]
216            || time_w.dims != [2, 4, hunits]
217            || time_b.dims != [4]
218        {
219            return Err(DataError::custom("LSTM dimension mismatch"));
220        }
221
222        #[cfg(debug_assertions)]
223        if !dic.iter_copied_values().all(|(_, g)| g < dic_len) {
224            return Err(DataError::custom("Invalid cluster id"));
225        }
226
227        Ok(Self {
228            model,
229            dic,
230            embedding,
231            fw_w,
232            fw_u,
233            fw_b,
234            bw_w,
235            bw_u,
236            bw_b,
237            time_w,
238            time_b,
239        })
240    }
241}
242
#[cfg(feature = "serde")]
impl<'de: 'data, 'data> serde::Deserialize<'de> for LstmDataFloat32<'data> {
    /// Reads the fields as-is, then validates them by going through `try_from_parts`.
    fn deserialize<S>(deserializer: S) -> Result<Self, S::Error>
    where
        S: serde::de::Deserializer<'de>,
    {
        // Field-for-field mirror of `LstmDataFloat32` that carries no validation of its own.
        #[derive(serde::Deserialize)]
        struct Raw<'data> {
            model: ModelType,
            #[serde(borrow)]
            dic: ZeroMap<'data, PotentialUtf8, u16>,
            #[serde(borrow)]
            embedding: LstmMatrix2<'data>,
            #[serde(borrow)]
            fw_w: LstmMatrix3<'data>,
            #[serde(borrow)]
            fw_u: LstmMatrix3<'data>,
            #[serde(borrow)]
            fw_b: LstmMatrix2<'data>,
            #[serde(borrow)]
            bw_w: LstmMatrix3<'data>,
            #[serde(borrow)]
            bw_u: LstmMatrix3<'data>,
            #[serde(borrow)]
            bw_b: LstmMatrix2<'data>,
            #[serde(borrow)]
            time_w: LstmMatrix3<'data>,
            #[serde(borrow)]
            time_b: LstmMatrix1<'data>,
        }

        let Raw {
            model,
            dic,
            embedding,
            fw_w,
            fw_u,
            fw_b,
            bw_w,
            bw_u,
            bw_b,
            time_w,
            time_b,
        } = <Raw as serde::Deserialize>::deserialize(deserializer)?;

        // Any validation failure is reported to serde under this one fixed message.
        Self::try_from_parts(
            model, dic, embedding, fw_w, fw_u, fw_b, bw_w, bw_u, bw_b, time_w, time_b,
        )
        .map_err(|_| <S::Error as serde::de::Error>::custom("Invalid dimensions"))
    }
}
293
#[cfg(feature = "datagen")]
impl databake::Bake for LstmDataFloat32<'_> {
    /// Emits a call to `from_parts_unchecked` whose arguments are the baked fields, in
    /// declaration order.
    fn bake(&self, env: &databake::CrateEnv) -> databake::TokenStream {
        // Exhaustive destructuring: adding a field to the struct forces an update here.
        let Self {
            model,
            dic,
            embedding,
            fw_w,
            fw_u,
            fw_b,
            bw_w,
            bw_u,
            bw_b,
            time_w,
            time_b,
        } = self;

        // Tuple elements are evaluated left to right, so fields are baked in declaration order.
        let (model, dic, embedding) = (model.bake(env), dic.bake(env), embedding.bake(env));
        let (fw_w, fw_u, fw_b) = (fw_w.bake(env), fw_u.bake(env), fw_b.bake(env));
        let (bw_w, bw_u, bw_b) = (bw_w.bake(env), bw_u.bake(env), bw_b.bake(env));
        let (time_w, time_b) = (time_w.bake(env), time_b.bake(env));

        databake::quote! {
            icu_segmenter::provider::LstmDataFloat32::from_parts_unchecked(
                #model,
                #dic,
                #embedding,
                #fw_w,
                #fw_u,
                #fw_b,
                #bw_w,
                #bw_u,
                #bw_b,
                #time_w,
                #time_b,
            )
        }
    }
}
325
#[cfg(feature = "datagen")]
impl databake::BakeSize for LstmDataFloat32<'_> {
    /// Adds up the `borrows_size` of every field.
    fn borrows_size(&self) -> usize {
        let per_field = [
            self.model.borrows_size(),
            self.dic.borrows_size(),
            self.embedding.borrows_size(),
            self.fw_w.borrows_size(),
            self.fw_u.borrows_size(),
            self.fw_b.borrows_size(),
            self.bw_w.borrows_size(),
            self.bw_u.borrows_size(),
            self.bw_b.borrows_size(),
            self.time_w.borrows_size(),
            self.time_b.borrows_size(),
        ];
        per_field.iter().sum()
    }
}
342
/// The data to power the LSTM segmentation model.
///
/// This data enum is extensible: more backends may be added in the future.
/// Old data can be used with newer code, but newer data cannot be used with older code.
///
/// Examples of possible future extensions:
///
/// 1. A variant that stores data in 16 instead of 32 bits
/// 2. Minor changes to the LSTM model, such as different forward/backward matrix sizes
///
/// <div class="stab unstable">
/// 🚧 This code is considered unstable; it may change at any time, in breaking or non-breaking ways,
/// including in SemVer minor releases. While the serde representation of data structs is guaranteed
/// to be stable, their Rust representation might not be. Use with caution.
/// </div>
#[derive(Debug, PartialEq, Clone, yoke::Yokeable, zerofrom::ZeroFrom)]
#[cfg_attr(feature = "datagen", derive(serde::Serialize, databake::Bake))]
#[cfg_attr(feature = "datagen", databake(path = icu_segmenter::provider))]
#[cfg_attr(feature = "serde", derive(serde::Deserialize))]
#[yoke(prove_covariance_manually)]
#[non_exhaustive]
pub enum LstmData<'data> {
    /// The data as matrices of zerovec f32 values.
    Float32(#[cfg_attr(feature = "serde", serde(borrow))] LstmDataFloat32<'data>),
    // New variants must go BELOW the existing ones.
    // Serde serializes based on the variant name and its index in the enum, so inserting or
    // reordering variants would change the index of the existing ones:
    // https://docs.rs/serde/latest/serde/trait.Serializer.html#tymethod.serialize_unit_variant
}

// Passes `LstmData` to `icu_provider`'s `data_struct!` macro, together with an attribute that
// is gated on the `datagen` feature.
// NOTE(review): what the macro generates is defined in `icu_provider` and is not visible from
// this file — confirm there before changing the arguments.
icu_provider::data_struct!(
    LstmData<'_>,
    #[cfg(feature = "datagen")]
);