// binaryninja/progress.rs

1use std::ffi::c_void;
2
3pub trait ProgressCallback: Sized {
4    type SplitProgressType: SplitProgressBuilder;
5
6    /// Caller function will call this to report progress.
7    ///
8    /// Return `false` to tell the caller to stop.
9    fn progress(&mut self, progress: usize, total: usize) -> bool;
10
11    unsafe extern "C" fn cb_progress_callback(
12        ctxt: *mut c_void,
13        progress: usize,
14        total: usize,
15    ) -> bool {
16        let ctxt: &mut Self = &mut *(ctxt as *mut Self);
17        ctxt.progress(progress, total)
18    }
19
20    #[allow(clippy::wrong_self_convention)]
21    unsafe fn into_raw(&mut self) -> *mut c_void {
22        self as *mut Self as *mut c_void
23    }
24
25    /// Split a single progress function into proportionally sized subparts.
26    /// This function takes the original progress function and returns a new function whose signature
27    /// is the same but whose output is shortened to correspond to the specified subparts.
28    ///
29    /// The length of a subpart is proportional to the sum of all the weights.
30    /// E.g. with `subpart_weights = &[ 25, 50, 25 ]`, this will return a function that calls
31    /// progress_func and maps its progress to the ranges `[0..=25, 25..=75, 75..=100]`
32    ///
33    /// Weights of subparts, described above
34    ///
35    /// * `progress_func` - Original progress function (usually updates a UI)
36    /// * `subpart_weights` - Weights of subparts, described above
37    fn split(self, subpart_weights: &'static [usize]) -> Self::SplitProgressType;
38}
39
40pub trait SplitProgressBuilder {
41    type Progress<'a>: ProgressCallback
42    where
43        Self: 'a;
44    fn next_subpart(&mut self) -> Option<Self::Progress<'_>>;
45}
46
47impl<F> ProgressCallback for F
48where
49    F: FnMut(usize, usize) -> bool,
50{
51    type SplitProgressType = SplitProgress<F>;
52
53    fn progress(&mut self, progress: usize, total: usize) -> bool {
54        self(progress, total)
55    }
56
57    fn split(self, subpart_weights: &'static [usize]) -> Self::SplitProgressType {
58        SplitProgress::new(self, subpart_weights)
59    }
60}
61
62pub struct NoProgressCallback;
63
64impl ProgressCallback for NoProgressCallback {
65    type SplitProgressType = SplitProgressNop;
66
67    fn progress(&mut self, _progress: usize, _total: usize) -> bool {
68        unreachable!()
69    }
70
71    unsafe extern "C" fn cb_progress_callback(
72        _ctxt: *mut c_void,
73        _progress: usize,
74        _total: usize,
75    ) -> bool {
76        true
77    }
78
79    fn split(self, subpart_weights: &'static [usize]) -> Self::SplitProgressType {
80        SplitProgressNop(subpart_weights.len())
81    }
82}
83
84pub struct SplitProgressNop(usize);
85
86impl SplitProgressBuilder for SplitProgressNop {
87    type Progress<'a> = NoProgressCallback;
88
89    fn next_subpart(&mut self) -> Option<Self::Progress<'_>> {
90        if self.0 == 0 {
91            return None;
92        }
93        self.0 -= 1;
94        Some(NoProgressCallback)
95    }
96}
97
98pub struct SplitProgress<P> {
99    callback: P,
100    subpart_weights: &'static [usize],
101    total: usize,
102    progress: usize,
103}
104
105impl<P: ProgressCallback> SplitProgress<P> {
106    pub fn new(callback: P, subpart_weights: &'static [usize]) -> Self {
107        let total = subpart_weights.iter().sum();
108        Self {
109            callback,
110            subpart_weights,
111            total,
112            progress: 0,
113        }
114    }
115
116    pub fn next_subpart(&mut self) -> Option<SplitProgressInstance<'_, P>> {
117        if self.subpart_weights.is_empty() {
118            return None;
119        }
120        Some(SplitProgressInstance { progress: self })
121    }
122}
123
124impl<P: ProgressCallback> SplitProgressBuilder for SplitProgress<P> {
125    type Progress<'a>
126        = SplitProgressInstance<'a, P>
127    where
128        Self: 'a;
129    fn next_subpart(&mut self) -> Option<Self::Progress<'_>> {
130        self.next_subpart()
131    }
132}
133
134pub struct SplitProgressInstance<'a, P: ProgressCallback> {
135    progress: &'a mut SplitProgress<P>,
136}
137
138impl<P: ProgressCallback> Drop for SplitProgressInstance<'_, P> {
139    fn drop(&mut self) {
140        self.progress.progress += self.progress.subpart_weights[0];
141        self.progress.subpart_weights = &self.progress.subpart_weights[1..];
142    }
143}
144
145impl<P: ProgressCallback> ProgressCallback for SplitProgressInstance<'_, P> {
146    type SplitProgressType = SplitProgress<Self>;
147
148    fn progress(&mut self, progress: usize, total: usize) -> bool {
149        let subpart_progress = (self.progress.subpart_weights[0] * progress) / total;
150        let progress = self.progress.progress + subpart_progress;
151        self.progress
152            .callback
153            .progress(progress, self.progress.total)
154    }
155
156    fn split(self, subpart_weights: &'static [usize]) -> Self::SplitProgressType {
157        SplitProgress::new(self, subpart_weights)
158    }
159}
160
#[cfg(test)]
mod test {
    use std::cell::Cell;

    use super::*;

    /// Returns a callback that stores the most recently reported progress value in `last`
    /// and always tells the caller to continue.
    fn recorder(last: &Cell<usize>) -> impl FnMut(usize, usize) -> bool + '_ {
        move |reported, _total| {
            last.set(reported);
            true
        }
    }

    #[test]
    fn progress_simple() {
        let last = Cell::new(0);
        let mut callback = recorder(&last);
        for step in [0, 1, 50, 99, 100] {
            callback.progress(step, 100);
            assert_eq!(last.get(), step);
        }
    }

    #[test]
    fn progress_simple_split() {
        let last = Cell::new(0);
        let mut split = recorder(&last).split(&[25, 50, 25]);

        // First subpart: 0..=25.
        {
            let mut part = split.next_subpart().unwrap();
            part.progress(0, 100);
            assert_eq!(last.get(), 0);
            part.progress(100, 100);
            assert_eq!(last.get(), 25);
        }

        // Second subpart: 25..=75.
        {
            let mut part = split.next_subpart().unwrap();
            part.progress(0, 100);
            assert_eq!(last.get(), 25);
            part.progress(25, 100);
            // The exact value depends on how the scaled progress is rounded (at the time of
            // writing it rounds down), so only check that it lands in the expected range.
            assert!((36..=37).contains(&last.get()));
            part.progress(50, 100);
            assert_eq!(last.get(), 50);
            part.progress(100, 100);
            assert_eq!(last.get(), 75);
        }

        // Third subpart: 75..=100.
        {
            let mut part = split.next_subpart().unwrap();
            part.progress(0, 100);
            assert_eq!(last.get(), 75);
            part.progress(100, 100);
            assert_eq!(last.get(), 100);
        }

        assert!(split.next_subpart().is_none());
    }

    #[test]
    fn progress_recursive_split() {
        let last = Cell::new(0);
        let mut split = recorder(&last).split(&[25, 50, 25]);

        // First subpart: 0..=25.
        {
            let mut part = split.next_subpart().unwrap();
            part.progress(0, 100);
            assert_eq!(last.get(), 0);
            part.progress(100, 100);
            assert_eq!(last.get(), 25);
        }

        // Second subpart (25..=75) is split again into 25..=50 and 50..=75.
        {
            let mut nested = split.next_subpart().unwrap().split(&[50, 50]);

            // 25..=50
            {
                let mut part = nested.next_subpart().unwrap();
                part.progress(0, 100);
                assert_eq!(last.get(), 25);
                part.progress(100, 100);
                assert_eq!(last.get(), 50);
            }

            // 50..=75
            {
                let mut part = nested.next_subpart().unwrap();
                part.progress(0, 100);
                assert_eq!(last.get(), 50);
                part.progress(100, 100);
                assert_eq!(last.get(), 75);
            }
        }

        // Third subpart: 75..=100.
        {
            let mut part = split.next_subpart().unwrap();
            part.progress(0, 100);
            assert_eq!(last.get(), 75);
            part.progress(100, 100);
            assert_eq!(last.get(), 100);
        }

        assert!(split.next_subpart().is_none());
    }
}
275}