While deep neural networks (DNNs) have achieved astonishing performance on complex real-life problems, training a good DNN has become increasingly challenging: when only the conventional empirical loss is minimized, it is difficult to ensure that the optimizers used will converge to reliable minima with satisfactory model performance.
In the new paper SS-SAM: Stochastic Scheduled Sharpness-Aware Minimization for Efficiently Training Deep Neural Networks, a Tsinghua University research team proposes Stochastic Scheduled SAM (SS-SAM), a novel and efficient DNN training scheme. Compared to baseline sharpness-aware minimization (SAM), the team’s approach achieves comparable or better model training performance at a much lower computation cost.
To maximize a model's generalization ability, it is crucial that DNN training efficiently converges to flat minima. Previous research introduced the sharpness-aware minimization (SAM) technique, which encourages optimizers to converge to flat minima by minimizing the loss sharpness. SAM's computational cost, however, can reach double that of standard stochastic gradient descent (SGD) training, as an additional forward-backward propagation is required for each parameter update. The proposed SS-SAM is designed to address this limitation, aiming to reduce the overall computational overhead while preserving model generalization ability.
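To see where the extra cost comes from, consider a minimal sketch of one SAM update on a toy one-dimensional loss L(w) = w². The perturbation radius `rho` and learning rate `lr` here are illustrative values, not taken from the paper:

```python
# Toy 1-D illustration of why SAM needs two forward-backward passes per step.
# rho (perturbation radius) and lr are hypothetical values for illustration.

def grad(w):
    """Gradient of the toy loss L(w) = w**2 (stands in for one backward pass)."""
    return 2.0 * w

def sam_step(w, lr=0.1, rho=0.05):
    g = grad(w)                        # pass 1: gradient at the current weights
    eps = rho * g / (abs(g) + 1e-12)   # ascend toward the locally sharpest direction
    g_adv = grad(w + eps)              # pass 2: gradient at the perturbed weights
    return w - lr * g_adv              # descend using the perturbed gradient

def sgd_step(w, lr=0.1):
    return w - lr * grad(w)            # single pass: standard SGD update
```

The two calls to `grad` in `sam_step`, versus one in `sgd_step`, are the source of SAM's roughly doubled per-step cost.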
In SS-SAM, the optimizer performs a Bernoulli trial at each training step, randomly choosing either an SGD or a SAM update with a probability determined by a predefined scheduling function. Specifying different scheduling functions thus enables the expected number of forward-backward propagations to be controlled.
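The step-selection mechanism can be sketched as follows. This is a simplified illustration, not the paper's implementation; `schedule(t, T)` stands for the scheduling function, and the loop counts propagations rather than performing real updates:

```python
import random

# Hypothetical sketch of SS-SAM step selection: at each step t, a Bernoulli
# trial with probability p = schedule(t, T) decides between a SAM step
# (two forward-backward passes) and an SGD step (one pass).

def ss_sam_train(num_steps, schedule, seed=0):
    """Count the forward-backward propagations used over a training run.
    `schedule(t, num_steps)` gives the probability of a SAM step at step t."""
    rng = random.Random(seed)
    passes = 0
    for t in range(num_steps):
        if rng.random() < schedule(t, num_steps):
            passes += 2   # SAM step: extra propagation at the perturbed weights
        else:
            passes += 1   # SGD step: single forward-backward propagation
    return passes
```

For example, a constant schedule of p = 0.5 yields an expected 1.5 propagations per step, halfway between plain SGD and full SAM.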
In their empirical study, the researchers examined four types of scheduling functions (constant, piecewise, linear and trigonometric), analyzing each function's expected propagation count, computational efficiency, and impact on model performance.
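The four families might look as follows; the specific forms and constants below are illustrative guesses, not the paper's exact schedules. Each maps a step t in [0, T) to a SAM probability in [0, 1], and the expected propagation count per step is simply 1 plus the mean probability over the run:

```python
import math

# Illustrative examples of the four scheduling-function families
# (constants and shapes are hypothetical, not taken from the paper).

def constant(t, T, p=0.5):
    return p                              # fixed SAM probability throughout

def piecewise(t, T):
    return 0.0 if t < T // 2 else 1.0     # SGD early, SAM late in training

def linear(t, T):
    return t / max(T - 1, 1)              # ramp the SAM probability up linearly

def trigonometric(t, T):
    return 0.5 * (1 - math.cos(math.pi * t / max(T - 1, 1)))  # smooth 0 -> 1 ramp
```

All four of these examples average to roughly p = 0.5 over a run, i.e. about 1.5 propagations per step, but they distribute the SAM steps across training very differently.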
The results show that with specific scheduling functions, models can achieve results comparable to SAM at an average propagation count of only 1.5 per step, a significant speedup; and that well-chosen scheduling functions can even boost model performance at a much lower computation cost.
The team says future work in this area could focus on exploring more effective scheduling functions to further improve computational efficiency and model generalization.
The paper SS-SAM: Stochastic Scheduled Sharpness-Aware Minimization for Efficiently Training Deep Neural Networks is on arXiv.
Author: Hecate He | Editor: Michael Sarazen