Memory optimization for loading weights for no_std mode #2871

Open
HerrMuellerluedenscheid opened this issue Mar 5, 2025 · 8 comments
Labels
enhancement Enhance existing features

Comments

@HerrMuellerluedenscheid

Hey folks,

Thanks for building a genius framework! From a recent issue, you probably remember that I tried running the SqueezeNet example on an ESP32. I switched to an ESP32-S3 with 8 MB PSRAM. After failing to run it there due to allocation failures, I started a discussion in the #esp-rs:matrix.org chat, which turned out to be super fruitful (BIG shoutout!).

A few key findings and questions, which I'll try to summarize from the thread:

  1. Burn starts by taking the neural network weights, pushing them into a Vec, and then cloning them, leading to 5 MB of RAM usage for data that is already readable from flash. More specifically, the first part is the generated squeezenet1.rs (generated in target/debug/build/squeezenet-burn-whatever/out/model):
impl<B: Backend> Model<B> {
    pub fn from_embedded(device: &B::Device) -> Self {
        let record = BinBytesRecorder::<HalfPrecisionSettings>::default()
            .load(EMBEDDED_STATES.to_vec(), device) // <-- here is an allocation that shouldn't be needed
            .expect("Should decode state successfully");
        Self::new(device).load_record(record)
    }
}
  2. In burn-core, Recorder::load clones the model again. Is that really necessary?
    fn load<R>(&self, args: Self::LoadArgs, device: &B::Device) -> Result<R, RecorderError>
    where
        R: Record<B>,
    {
        let item: BurnRecord<R::Item<Self::Settings>, B> =
            self.load_item(args.clone()).map_err(|err| { // <-- here
  3. Can Burn operate directly off EMBEDDED_STATES, i.e. without copying the model to RAM?

  4. Ideally, the model should be split into a read-only part and a read-write part, with the read-only part used as-is, i.e. flashed un-decoded.

Apparently there is room for improvement when running models in no_std, which I guess will also be beneficial when running with std.
Looking forward to hearing your thoughts.

@antimora
Collaborator

antimora commented Mar 5, 2025

Yes, I agree there is a lot of room for improvement in weight loading. We haven't focused on this initially, but I think this is the perfect time. Since you're in the weeds of it, it would be great if you could try using more efficient Rust APIs that consume the existing preallocated memory without duplicating it (even temporarily). I know there are some. We will review your PR.

@antimora antimora added the enhancement Enhance existing features label Mar 5, 2025
@antimora antimora changed the title optimize burn for no_std Memory optimization for loading weights for no_std mode Mar 5, 2025
@BjornTheProgrammer
Contributor

I have the exact same issues and noticed the same things on the Raspberry Pi Pico. I would be willing to tackle this issue as well, together with a team I'm working with. Is there an active PR for this, or should I create one?

@HerrMuellerluedenscheid
Author

HerrMuellerluedenscheid commented Mar 9, 2025

@BjornTheProgrammer I just pushed some experiments to #2881. On my ESP32-S3 it no longer panics because of allocation failures.
I'm getting this error now instead:
/crates/burn-core/src/record/memory.rs:39:85: called `Result::unwrap()` on an `Err` value: InvalidIntegerType { expected: U32, found: Reserved }
But I consider this already some success with regard to memory. So feel free to take a look, modify, be inspired :) I would love to have this working on some MCUs.

@ionspin

ionspin commented Mar 16, 2025

This is what I'm seeing when using BinBytesRecorder to load a ~256 kB model. I presume some of it comes from bincode deserialization.

Before loading: Stats {
    allocations: 0,
    deallocations: 0,
    reallocations: 0,
    bytes_allocated: 0,
    bytes_deallocated: 0,
    bytes_reallocated: 0,
}
Stats at 1: Stats {
    allocations: 343,
    deallocations: 246,
    reallocations: 0,
    bytes_allocated: 1072824,
    bytes_deallocated: 802391,
    bytes_reallocated: 0,
}

I've been trying to add zero-copy serialization via rkyv, but it's not that straightforward: rkyv requires the Archive annotation, and I can't figure out where in the macros the Serialize annotation is added to the modules, or how it is handled. I'd like pub trait Recorder's save_item and load_item to become fn save_item<I: Serialize + Archive>(....

@BjornTheProgrammer
Contributor

BjornTheProgrammer commented Mar 16, 2025

Wow! This is a great find @ionspin! I'll experiment with rkyv on my fork as well in #2892. Maybe we could approach it by first creating a new recorder.

@ionspin

ionspin commented Mar 16, 2025

> Wow! This is a great find @ionspin! I'll experiment with rkyv on my fork as well in #2892. Maybe we could approach it by first creating a new recorder.

I agree, and that's exactly where I got stuck (please note that I've never used rkyv before, so I might be going the wrong way): the Recorder interface has a Serialize bound in

fn save_item<I: Serialize>(...

and similarly in its counterpart load_item, while rkyv requires Archive. I think there are quite a few places that would require adding Archive as an annotation, so I'd like to figure those out before continuing.

Is there a way to use rkyv without Archive?

I'm almost tempted to grab the weight tensors directly from my model and serialize and deserialize them myself as a quick hack.

@ionspin

ionspin commented Mar 16, 2025

Oh, and also note that I'm not claiming all of those allocations I posted come from bincode. I haven't looked closely at what happens after the bytes are deserialized by bincode, and there are probably more allocations/deallocations there.

@BjornTheProgrammer
Contributor

My PR #2892 has been merged! The memory savings weren't quite what I expected; I believe most of the optimization will probably come from the executor backend. I'm going to try to create some tooling or a process for inspecting memory usage in depth, to really discover where the best savings can come from.
