Introduction

Ensembles have proven to be very effective for various regression and classification tasks [1]. Given a cohort of networks, each situated in a different local minimum [2] and hence performing better on a slightly different part of the input space [3], a natural question arises: can we select which network, or mixture of networks, to use for a specific input? The answer comes from the Mixture of Experts (MoE) framework [4], where an extra selection function, or gater, is trained to combine or drop individual predictions. Such systems open the path for further optimizations in terms of both accuracy and speed.

This paper explores the problem of computational efficiency in ensemble frameworks. To this end, we build an ensemble of sub-problem-specialized models (branches), split the input space into multiple partitions, enforce a specialization of the models for each partition, and train a gating mechanism to decide which branch should be used for each input sample. Similar to [5], we take the top-K predictions of the gater to combine the branches' projections and obtain the final classification. Conditional execution is not new in the literature, but our original contributions include (1) an extra step that enforces individual model specialization and (2) a dynamic exclusion of some of the top-K predictions when the gater is sufficiently confident. For example, with 20 branches and K = 4, the dynamic exclusion step may leave only 1 or 2 active branches, which reduces the computational cost. The proposed specialization step pushes each branch to extract discriminative features targeting only its assigned task, ensuring that the final prediction is preserved even when some branches are excluded (i.e., a single branch is specialized on a target prediction class and is therefore sufficient to obtain an accurate prediction).

We evaluate our model, DynK-Hydra, on the CIFAR-100 [6], Food-101 [7], CUB-200 [8], and ImageNet32 [9] data sets and obtain an improvement of 4.3% mean accuracy compared to the widely employed ResNet [10] architectures, while reducing the inference flops by a factor of 2 to 5.5. Compared to the HydraRes framework [5], we obtain a marginal 1.2% accuracy improvement on pairwise architectures, but we reduce the inference time by a factor of up to 2.8.

Related work

Training the same model multiple times on the same data set produces a cohort of models, each situated in a different local minimum [2, 3]. This implies that each model performs better on a slightly different subset of the data set. A larger model (namely, an ensemble) created from a combination of these sub-models can benefit from this diversity and thus exhibit a boosted accuracy.

Ensembles

From an empirical point of view, many frameworks have benefited from aggregating multiple features, ranging from ensembling models [4, 11] to graph neural networks (GNNs) [12, 13]. In [12], principal component analysis is used as an aggregator function for the feature nodes of a GNN, to preserve the real signals from neighboring features while simultaneously filtering out Gaussian noise. In [13], the authors state that injectivity is necessary to improve the representation capacity of GNNs and propose several mechanisms for converting non-injective aggregation functions into injective ones. A concrete study is performed on combining the injective aggregation function with the node feature encodings of the GNNs.

In the case of model ensembling, there are many ways to obtain the desired ensemble (accurate, with high diversity), but the most common are to vary the models' architectures, to vary the training data, or to combine the outputs differently.

From the first category, [14] dynamically creates an ensemble of regressors, with emphasis on both accuracy and diversity. The training starts with an ensemble of two neural networks (NNs), each having a single hidden unit. At each training step, if the ensemble error is still above a given threshold, a new NN is inserted or more hidden units are added, until a halting criterion is reached (usually when the accuracy stops increasing).

Varying the training data involves splitting the data set, so that each model sees a different subset or a differently augmented subset. This category includes algorithms such as boosting [15], bagging [16], and stacking [17].

The final output is usually obtained either with a fusion-based method or with a selection-based method. In the first category, one can use majority voting or mean predictions, which treat all the classifiers/regressors equally. Snapshot Ensembling [18] is a method where a network is trained for a larger number of steps and, each time it converges, a snapshot is saved (hence a new model is obtained) and the learning rate is reset. In the Deep Mutual Learning (DML) [1] framework, a cohort of similar models is trained together with an extra knowledge-distillation loss, which ensures they learn from each other. In the selection-based category, a set of weights (binary in the extreme case) is assigned to each prediction. One specific case is MoE [4], which divides the problem space between multiple models and trains a meta-model (usually a gating network) to combine them. We discuss this further in Sect. 2.2.

A great deal of interest has been given to the concept of diversity, how to measure it, and how it influences the ensemble’s performance. For regression problems, [19] demonstrates that the quadratic error of an ensemble is less than or equal to the average quadratic error of its components. In the classification case, there are few theoretical results, mostly because diversity is hard to define in this setting. For example, [20] categorizes the ensemble’s diversity based on the coincident error, i.e., the number of classifiers that give the wrong answer. Based on this, the authors define four levels: (1) no coincident error, (2) some coincident errors, but majority voting is always correct, (3) at least one model is correct, but majority voting is not always correct, (4) there are samples for which no model is correct. This ranking is not always very useful, but [20] concludes with some important remarks: for an ensemble at Level 2 there exists a subset of members (or models) that takes the ensemble to Level 1, and, in the same manner, an ensemble at Level 3 can be upgraded to Level 2. In other words, by selecting a subset of predictions, we can increase the probability that the ensemble gives the right result. These observations drive us to further investigate the gating mechanism, as described in Sect. 2.2.

Gating mechanisms

In an ensemble framework, we call a gater a network that analyzes the input samples, the models’ intermediate predictions [5], or both, and determines an efficient way to combine (or drop) [21] a part of the models’ predictions in order to obtain a better ensembling result. This mechanism is also used outside the ensemble scope. For example, in [22, 23] the gating mechanism is applied at the level of intermediate activation maps within a neural network; in [22], a backbone network (used for a specific task) is accompanied by a gater network, which learns which filters of the backbone to discard (set to 0).

In this work, we are more interested in MoE frameworks, where a separate predictor is trained to decide how to select and combine the prediction branches [5, 11, 21, 24]. Given a sample I, \(n_b\) predictors \(P_i(I)\), and a gater G(I) (an array of size \(n_b\) with values \(\in [0,1]\)), the output of the ensemble E is:

$$\begin{aligned} E(I) = \sum _{i=1}^{n_b} G(I)_i P_i(I). \end{aligned}$$
(1)
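As an illustration, Eq. 1 can be sketched in a few lines of NumPy; the predictor and gater callables and their output shapes below are purely illustrative.

```python
import numpy as np

def ensemble_output(I, predictors, gater):
    """Eq. 1: weighted sum of expert predictions, with weights given by the gater.

    `predictors` is a list of n_b callables P_i(I) returning class scores and
    `gater` returns an array of n_b weights in [0, 1] (illustrative interfaces).
    """
    weights = gater(I)                              # shape: (n_b,)
    preds = np.stack([P(I) for P in predictors])    # shape: (n_b, num_classes)
    return (weights[:, None] * preds).sum(axis=0)   # E(I), shape: (num_classes,)
```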

In On-the-fly Native Ensemble (ONE) [21], several experts and the gater share the low-level layers (similar to [5]) and together form a self-gating ensemble. The gater decides which branches’ predictions are discarded or kept in the final prediction. Finally, using a knowledge-distillation loss, the ensemble prediction is distilled back to all the composing models.

An instance of conditional computation, namely Deep Mixture of Experts (DMoE), is introduced in [11], where an input-based gating network learns to combine a list of experts. The gating framework is deep because it operates at multiple levels within the networks.

In the HydraRes [5] and Sparsely Gated MoE [24] frameworks, a large number of models (experts) are trained to solve a classification problem. The gater performs a top-K selection of predictors, which are then used in the final decision. In the case of HydraRes, each branch predicts an embedding array rather than a final set of logits. Furthermore, each model is encouraged to specialize on a specific part of the data set, with the sub-partitions being determined a priori. The number of experts involved in their experiments varies from 5 to 50. The work of [24] is applied to a language modeling task, where the number of experts varies from 32 to 4096. The model consists of a word embedding layer, two LSTM (Long Short-Term Memory) layers, a MoE layer between them, and a softmax layer.

A major issue when designing the gater is that it tends to produce large weights for a small subset of experts that converge faster [11, 24]. In [11], the problem is mitigated by enforcing an equal distribution of expert usage in the initial training phase, and in [24] the issue is tackled with an extra loss term, which encourages all experts to be equally important. Due to the experts’ specialization, this problem is automatically alleviated in [5]. We observe the same effect in our case and detail it further in Sect. 3.

To conclude, MoE frameworks have the following advantages [25]: efficiency (selective model activation), representation power (the possibility to have a large number of parameters that are parallelizable at training time), adaptability (to hardware constraints), compatibility (most existing machine learning techniques can integrate MoE), genericity (used for tasks in computer vision, language modeling, etc.), and interpretability (the possibility to analyze the actively responding models).

Proposed method

In this paper, we propose the DynK-Hydra framework, composed of a cohort of \(n_b\) models (also called branches) and a gater G, and we test it on multiple image classification data sets (CIFAR-100 [6], Food-101 [7], CUB-200 [8], and ImageNet32 [9]). As in [5], the low-level layers (the stem network) are shared among all branches and the gater. For each input image I, the gater predicts a set of weights \(w_i, i \in \{1, n_b\}\), where each \(w_i\) represents the gater’s affinity for selecting branch i for the final prediction. The individual branches predict a set of embeddings (or projections) \(proj_i(I)\), which are combined based on the gater’s decision G(I) = {\(w_i, i \in \{1, n_b\}\)}. The embeddings of the top-K selections of G(I) are summed, passed through a combiner, and then fed into a classification layer, which outputs the final ensemble prediction E(I), as in Eq. 2.

$$\begin{aligned} E(I) = fc ( GAP (\sum _{b \in topK(G(I))} proj_b(I))), \end{aligned}$$
(2)

where fc is the fully connected classification layer, and GAP is the global average pooling operator.

The entire inference process is formalized in Algorithm 1.

Algorithm 1 (DynK-Hydra inference)
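The following PyTorch-style sketch approximates the inference flow described above (shared stem, gater, top-K selection, dynamic exclusion with threshold \(\tau \), and combination as in Eq. 2); all module names, the fallback rule, and the default \(\tau \) value are illustrative assumptions rather than the exact implementation.

```python
import torch

def dynk_hydra_inference(image, stem, gater_head, branches, combiner, fc, K=4, tau=0.1):
    """Approximate DynK-Hydra forward pass for a single (unbatched) image.

    `stem`, `gater_head`, `branches`, `combiner`, and `fc` are assumed modules;
    K and the default tau are illustrative values.
    """
    feats = stem(image)                                    # shared low-level features
    w = torch.softmax(gater_head(feats), dim=-1)           # gater weights over the n_b branches
    top_w, top_idx = torch.topk(w, K)                      # static top-K pre-selection
    keep = top_idx[top_w > tau]                            # dynamic exclusion step
    if keep.numel() == 0:                                  # assumed fallback: most confident branch
        keep = top_idx[:1]
    proj = sum(branches[i](feats) for i in keep.tolist())  # sum the selected projections
    proj = combiner(proj)                                  # convolutional combiner
    pooled = proj.mean(dim=(-2, -1))                       # global average pooling (GAP)
    return fc(pooled)                                      # final ensemble logits E(I)
```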

Upon training, each branch should be specialized on a different partition of the data set. This reduces the complexity of the classification task for each branch. Assuming we have M classes, we group them into \(n_b\) clusters (\(n_b < M\)) and each branch is assigned to one cluster (more details in Sect. 3.1). So far, the system is similar to HydraRes [5], but on top of that, we add two original contributions: (1) whereas in [5] the hyper-parameter K (from the top-K selection) is set to a fixed maximum value (usually 4), in our case it is dynamically decreased based on the gater’s prediction, and (2) we employ an extra loss term which ensures that each branch is specialized on its assigned sub-partition, regardless of the gater selection (see Sect. 3.3). The first contribution implements an extra branch selection mechanism, which reduces the inference time and avoids the use of unnecessary extracted features by cutting out complete branches, while the second enforces the specialization of the branches, regardless of the gater activation.

Class partitioning

The main purpose of partitioning is to group the input space into clusters of classes that are similar from the model’s perspective. Each branch is responsible for extracting discriminative features from its assigned cluster.

For data sets with a small number of classes, manual clustering is possible, but even so, it is not guaranteed that the classes chosen for each cluster are representative (similar) in the model’s feature space. For this reason, we follow a process similar to the one described in [5].

The class partitioning process is described in Algorithm 2.

Algorithm 2 (class partitioning)

First, we use the state-of-the-art EfficientNet-B7 [26] architecture (pretrained on ImageNet [9]) to predict an embedding (\(\in {\mathbb {R}}^{2560}\), the dimensionality of the EfficientNet-B7 embedding vector) for each image. Then, we compute a class representative R as the mean of all embeddings of the images within that class. For example, in the case of the CIFAR-100 data set, we end up with 100 representatives corresponding to the 100 classes. The next step is to cluster these representatives with the kMeans algorithm into \(n_b\) clusters. For \(n_b=10\), the clusters of embeddings are illustrated in Fig. 1, where each color represents a cluster of 10 classes and each point is an image embedding vector. It can be observed that the images are grouped adequately in the coarse-label space.

Fig. 1

Subtask partitioning (best viewed in color). Each point represents an image and the associated color is one of the \(n_b=10\) clusters. For dimensionality reduction (from \({\mathbb {R}}^{2560}\) to 3D), we used t-SNE [27]

The next step is to compute the class partitioning P by assigning \(M/n_b\) classes to each cluster. To ensure a balanced result, each cluster center is assigned to its nearest remaining class, and this process is repeated \(M/n_b\) times, similar to [5]. Finally, each class partition is allocated to one of the \(n_b\) branches of the architecture.
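A compact sketch of this partitioning procedure is given below; it assumes the EfficientNet-B7 embeddings have already been extracted (the feature extraction itself is not shown), and the helper and variable names are illustrative.

```python
import numpy as np
from sklearn.cluster import KMeans

def partition_classes(embeddings, labels, n_b):
    """Group the M classes into n_b balanced clusters (sketch of Algorithm 2).

    `embeddings`: (num_images, 2560) array of EfficientNet-B7 features,
    `labels`: (num_images,) array of fine class labels.
    """
    classes = np.unique(labels)
    # Class representatives R: mean embedding of each class.
    reps = np.stack([embeddings[labels == c].mean(axis=0) for c in classes])
    centers = KMeans(n_clusters=n_b, random_state=0).fit(reps).cluster_centers_

    partition = {b: [] for b in range(n_b)}
    unassigned = set(range(len(classes)))
    for _ in range(len(classes) // n_b):            # M / n_b rounds ensure balanced clusters
        for b in range(n_b):
            free = sorted(unassigned)
            dists = np.linalg.norm(reps[free] - centers[b], axis=1)
            nearest = free[int(np.argmin(dists))]   # nearest remaining class to this center
            partition[b].append(int(classes[nearest]))
            unassigned.remove(nearest)
    return partition                                # class partition P: branch index -> list of classes
```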

Subtask specialization

Each branch of the architecture predicts an embedding (also called a projection), and the gater is responsible for selecting a subset of branches to create the final prediction. Therefore, the system accuracy depends both on the branch projections and on the discriminative power of the gater. The latter can be seen as a predictor on the coarse classification task, where the classes are determined as described in Sect. 3.1. Ideally, the branches are strictly specialized on their given sub-tasks, and the gater selects a single embedding for performing the final classification. As described in [5], to decouple the final accuracy from the gating function G, we do not rely only on the best selected branch, but on the top-K predictions of G, where K is a fixed parameter in [5] and a dynamically determined, image-dependent variable in our case.

In [5], two cross-entropy losses guide the training process, one for the final classification and one for the gater. In the same manner, we use the cross-entropy loss to optimize the final prediction and the gater selection. The ensemble loss (i.e., the final prediction loss) \(L_E\) is presented in Eq. 3.

$$\begin{aligned} \small L_E = \frac{1}{N}\sum _i^N -\log \left( \frac{e^{y_i^{gt}}}{\sum _j^c e^{y_i^{j}}}\right) , \end{aligned}$$
(3)

where N is the number of samples in the mini-batch, c is the number of classes, \(y_i\) is the prediction vector associated with the \(i^{th}\) sample, \(y_i^{j}\) is the prediction for class j, and gt is the index of the ground-truth class.

The second loss, \(L_G\), enforces the gater to activate just the correct branch, based on the predefined coarse labels (see Sect. 3.1). The gater loss function is the cross-entropy on the coarse labels:

$$\begin{aligned} \small L_{G} = \frac{1}{N}\sum _i^{N} -\log \left( \frac{e^{y_i^{gt\_coarse}}}{\sum _j^{coarse\_labels} e^{y_i^{j}}}\right) , \end{aligned}$$
(4)

where the \(coarse\_labels\) are the newly extracted labels (super-classes), as explained in Sect. 3.1, and the rest of the notations are the same as in Eq. 3.
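Both cross-entropy terms can be written compactly as below; the mapping from fine classes to cluster (coarse) labels is assumed to be a lookup tensor built from the partition of Sect. 3.1, and the names are illustrative.

```python
import torch.nn.functional as F

def ensemble_and_gater_losses(ens_logits, gater_logits, fine_labels, class_to_cluster):
    """L_E (Eq. 3) on the fine labels and L_G (Eq. 4) on the coarse labels.

    `class_to_cluster` is a LongTensor of size M mapping each fine class to its
    cluster index (illustrative name, built from the partition of Sect. 3.1).
    """
    L_E = F.cross_entropy(ens_logits, fine_labels)       # Eq. 3
    coarse_labels = class_to_cluster[fine_labels]        # super-class targets
    L_G = F.cross_entropy(gater_logits, coarse_labels)   # Eq. 4
    return L_E, L_G
```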

During back-propagation, the weights of the branches that are not involved in the final prediction are not altered by the ensemble error; they are only updated through the \(L_{branch}\) loss (see Sect. 3.3). Likewise, the independent weights of the gater are not changed by the ensemble prediction error.

We illustrate the forward and backward propagation flow in Fig. 2.

Fig. 2

Backpropagation flow (best viewed in color). Three losses guide the training step: the final prediction loss \(L_{E}\) (red), the gater loss \(L_G\) (orange), and the individual branch loss \(L_{branch}\) (blue). The dotted lines represent the backpropagation pass corresponding to each loss. In this example, the gater selects branches 2 and \(n_b\), so the independent layers of branch 1 are not altered by \(L_{E}\) at this specific step. The independent gater layers also do not depend on the final classification performance

When dealing with ensembles, the diversity of the members is an actively discussed topic in the literature, especially its connection to the overall accuracy [3, 20, 28]. As presented in Sect. 2, the main benefit of using an ensemble is that each sub-model converges to a different local minimum and hence performs better on a slightly different part of the data set. In our case, as in [5], this diversity is maximized, because each branch is explicitly specialized on a certain part of the data set.

Enforcing specialization

In the previously described back-propagation flow, the branches receive feedback only from the samples assigned to them, and only if the gater chooses them to participate in the final prediction (which is not always the case, especially in the first part of the training). A contribution of our work is that we directly guide the individual branches to learn the specific features of their given sub-tasks. This loss can be seen as an extra sub-task specialization step for each branch.

Specialization enforcement is achieved by adding an extra classification layer on top of each branch’s projection output. The classification layer outputs a probability distribution over all the data set classes, but the loss associated with it provides feedback only for the branch’s sub-partition classes. The intuition is that, by having a second task of learning to discriminate over the given classes, each branch becomes specialized in extracting descriptive features, regardless of the gating performance.

The extra loss on each individual branch is a categorical cross-entropy loss applied on the selected samples \(Part_{branch}\). For each branch, the loss is computed only on the samples belonging to a class included in that branch’s partition. This is mathematically described as:

$$\begin{aligned} \small L_{branch} = \frac{1}{|Part_{branch} |}\sum _{i \in {Part_{branch}}} -\log \left( \frac{e^{y_i^{gt}}}{\sum _j^c e^{y_i^j}}\right) , \end{aligned}$$
(5)

where \(Part_{branch}\) is the set of mini-batch indices corresponding to the branch’s partition. The rest of the notations are the same as in Eq. 3.
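Eq. 5 amounts to a cross-entropy restricted to the samples of the branch’s partition; a sketch with illustrative names is shown below.

```python
import torch
import torch.nn.functional as F

def branch_specialization_loss(branch_logits, fine_labels, branch_classes):
    """L_branch (Eq. 5) for a single branch.

    `branch_logits`: (N, c) outputs of the branch's auxiliary classification layer,
    `branch_classes`: 1-D LongTensor with the fine classes assigned to this branch.
    """
    mask = torch.isin(fine_labels, branch_classes)   # Part_branch: samples of this partition
    if mask.sum() == 0:                              # no assigned samples in this mini-batch
        return branch_logits.new_zeros(())
    return F.cross_entropy(branch_logits[mask], fine_labels[mask])
```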

Eliminate waste

At this step, if each branch is specialized on a specific partition of the data set, one can question the relevance of the remaining K-1 branches included in the final prediction. During evaluation, we observed that some gating predictions (from the top-K selection) are close to zero, so we conclude that some embeddings (branch outputs) are not required for that specific input. The extra K-1 embeddings may introduce unnecessary information and can deteriorate the output. An original contribution of our work is that we dynamically select a subset of branches out of the top-K gater predictions, based on how confident the gater is in its output. We do this by setting a threshold \(\tau \) on the (softmax) gater prediction and removing the branches whose predicted probability is less than \(\tau \). As a result, in our framework, Eq. 2 is replaced by Eq. 6.

$$\begin{aligned} E(I) = fc ( GAP (\sum _{b \in topK(G(I))} dynProj_b(I))), \end{aligned}$$
(6)

where

$$\begin{aligned} dynProj_b(I) = {\left\{ \begin{array}{ll} proj_b(I), &{} G(I)_b > \tau \\ 0,&{} otherwise \end{array}\right. }. \end{aligned}$$

The reasons behind this extra selection step are: (1) reducing the amount of computation (flops) in the inference phase, by running a smaller number of specialized branches, and (2) increasing the accuracy of the overall prediction, as only the necessary extracted features are combined. More details are given in Sect. 4.
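A small worked example illustrates the effect: with the (made-up) gater probabilities below, \(n_b=6\), \(K=4\), and \(\tau =0.1\), only two of the four pre-selected branches are actually executed.

```python
import torch

w = torch.tensor([0.62, 0.21, 0.08, 0.05, 0.02, 0.02])  # illustrative gater softmax, n_b = 6
top_w, top_idx = torch.topk(w, k=4)                      # static top-K keeps branches 0, 1, 2, 3
keep = top_idx[top_w > 0.1]                              # dynamic exclusion with tau = 0.1
print(keep.tolist())                                     # -> [0, 1]: only 2 of the 4 branches run
```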

The overall loss function is a weighted sum of the ensemble classification \(L_E\), gater \(L_G\), and individual \(L_{branch}\) losses:

$$\begin{aligned} L_{final} = \alpha _1L_E + \alpha _2L_G + \alpha _3\sum _{i=1}^{n_b} L_{branch_i}, \end{aligned}$$
(7)

where \(\alpha _1\), \(\alpha _2\), and \(\alpha _3\) are the weights of the ensemble, gater, and individual branch losses, respectively.
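Using the loss sketches above, Eq. 7 reduces to a one-line combination; the default weights mirror the values reported later in Sect. 4 and are otherwise configurable.

```python
def total_loss(L_E, L_G, branch_losses, alpha1=1.0, alpha2=1.5, alpha3=0.05):
    """L_final (Eq. 7): weighted sum of the ensemble, gater, and per-branch losses.

    Defaults follow the weights reported later in the paper (alpha1=1, alpha2=1.5, alpha3=0.05).
    """
    return alpha1 * L_E + alpha2 * L_G + alpha3 * sum(branch_losses)
```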

Results and discussion

In this section, we discuss the implementation details of our architecture, the training hyper-parameters, the data sets, and the numerical results.

Network architectures

Our DynK-Hydra architecture is inspired by the work in [5]. It consists of \(n_b\) branches and a gater, which share the low-level convolutional layers (the network stem). Specifically, the first two blocks of ResNet are part of the stem, while the third block is replicated for each individual branch and for the gater. A fully connected layer (with \(n_b\) neurons) is attached on top of the last block of the gater to compute the gating prediction.

The outputs of the individual branches are summed, passed to a convolutional layer, and then fed to a prediction layer of size c, equal to the number of classes.

During training, as opposed to [5], an extra classification layer is attached on top of each individual branch and trained to yield accurate predictions on the assigned data set partition, thus enforcing branch specialization. Note that this extra classification layer does not directly influence the ensemble classification and is removed after training (during inference).

Another distinctive aspect of our framework is that the combiner (which uses the gater predictions to combine the branch projections) does not always take all the top-K projections, but usually fewer, depending on the confidence of the gater. The parameter d specifies the number of layers stacked in each block. The entire architecture is illustrated in Fig. 3, and a schematic code sketch follows the figure.

Fig. 3

DynK-Hydra architecture. As described in [5], the branches and the gater share a common set of convolutional layers. The parameter d specifies the number of times a block is replicated. As opposed to [5], we add classification layers on top of each branch and a gating threshold at the level of the combiner
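The layout can be summarized by the following schematic sketch; the block factory, the gater pooling head, and the 1×1 combiner are assumptions used only to make the structure explicit, not the exact implementation of [5] or of our network.

```python
import torch.nn as nn

class DynKHydra(nn.Module):
    """Schematic DynK-Hydra layout; `make_block` is a hypothetical factory that
    returns one ResNet block (block definitions follow [10] and are not shown)."""

    def __init__(self, make_block, n_b, n_classes, gater_dim, proj_channels):
        super().__init__()
        # First two ResNet blocks form the shared stem.
        self.stem = nn.Sequential(make_block(1), make_block(2))
        # The third block is replicated once per branch and once for the gater.
        self.branches = nn.ModuleList(make_block(3) for _ in range(n_b))
        self.gater = nn.Sequential(make_block(3), nn.AdaptiveAvgPool2d(1), nn.Flatten(),
                                   nn.Linear(gater_dim, n_b))
        # Training-only heads enforcing branch specialization (removed at inference).
        self.branch_heads = nn.ModuleList(nn.Linear(proj_channels, n_classes) for _ in range(n_b))
        # Convolutional combiner and final classification layer (Eq. 2).
        self.combiner = nn.Conv2d(proj_channels, proj_channels, kernel_size=1)
        self.fc = nn.Linear(proj_channels, n_classes)
```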

Data set

For our experiments, we used the CIFAR-100 [6], Food-101 [7], CUB-200 [8], and ImageNet [9] data sets. CIFAR-100 consists of 50k images for training and 10k for testing, labeled with 100 classes; the dimensions of each image are 32 \(\times \) 32 with 3 channels (RGB). Food-101 consists of 101k images organized into 101 categories of food (750 training images and 250 test images per category), with different image resolutions, scaled such that the maximum side has a length of 512 pixels. We obtain the downsampled version of Food-101 by padding with zeros to get a square image and then resizing it to 32 \(\times \) 32. CUB-200 consists of 11k images, from which we take 9k for training; we chose this data set because it has a large number of classes (i.e., 200). Its samples were cropped centrally to a square and scaled to 224 \(\times \) 224. The ImageNet data set gathers 1.2M images from 1000 classes; due to the limited computational resources available, in our experiments we use the 32 \(\times \) 32 downsampled version (ImageNet32).
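The Food-101 downsampling described above can be sketched as follows; centering the image on the padded canvas and the interpolation mode are assumptions.

```python
from PIL import Image

def downsample_food101(img):
    """Zero-pad to a square canvas, then resize to 32x32 (Food-101 preprocessing sketch)."""
    side = max(img.size)
    canvas = Image.new("RGB", (side, side))   # black (zero) padding
    canvas.paste(img, ((side - img.width) // 2, (side - img.height) // 2))
    return canvas.resize((32, 32), Image.BILINEAR)
```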

Even though the CIFAR-100 data set is annotated with coarse semantic labels, which would be suitable for 20 branches, we perform the steps detailed in Sect. 3.1 to obtain the coarse labels. We do this to show that the manual human semantic coarse-labeling of the data set is not mandatory and can, on the contrary, be harmful to the final accuracy, as we show in Table 5.

Numerical results

We train the baseline ResNet [10] and DenseNet [29] models, the original HydraRes and Hydra-Dense [5], and our DynK-Hydra model (using both ResNet and DenseNet blocks). The ResNet versions were trained on the CIFAR-100, Food-101, and CUB-200 data sets, while the DenseNet versions were trained on CIFAR-100 (to show the method’s efficacy regardless of the architecture).

In what follows, we describe the scenarios used to train the ResNet architectures (including HydraRes and DynK-Hydra). On the CIFAR-100 data set, we train for 190 epochs, using the AdamW optimizer [30] with an initial learning rate of \(10^{-3}\). Following a pattern similar to [10], we decay the learning rate at epochs 80, 140, 170, and 180 by the factors \(10^{-1}\), \(10^{-2}\), \(10^{-3}\), and \(0.5\times 10^{-3}\), respectively. For a fair comparison with [5], we choose the number of branches \(n_b=20\) and \(K=4\).
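The CIFAR-100 schedule can be expressed as the following optimizer/scheduler sketch (the helper name is illustrative; the scheduler is stepped once per epoch).

```python
import torch

def make_optimizer_and_scheduler(model):
    """AdamW with the piecewise learning rate decay used for CIFAR-100."""
    def lr_factor(epoch):
        if epoch < 80:
            return 1.0
        if epoch < 140:
            return 1e-1
        if epoch < 170:
            return 1e-2
        if epoch < 180:
            return 1e-3
        return 0.5e-3

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_factor)
    return optimizer, scheduler   # call scheduler.step() once per epoch, 190 epochs in total
```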

For the Food-101 data set, we train for 100 epochs (we use this number because the number of samples per class is double that of CIFAR-100); the learning rate is decayed at epochs 60, 80, and 90, and the parameter \(K=4\).

Due to the high computational cost (there are 1.2M images), ImageNet32 was trained for 70 epochs, with learning rate decays at epochs 40 and 60 by factors of \(10^{-1}\) and \(10^{-2}\), respectively. The number of branches is 20, and K is set to 5. The batch size is 32 and no cutout processing was employed.

For the CUB-200 data set, we use as backbone the ResNet version with a 224 \(\times \) 224 input, and we train for 160 epochs, with a learning rate decay at epoch 100, \(n_b=8\), and \(K=4\). Because the network becomes very large, we scale down the architecture and replace the convolution layers with depthwise separable convolutions. The downscaling reduces the output channels (as in [5]) using the following width parameters: stem (\(w_s=0.5\)), branches (\(w_b=0.125\)), and gater (\(w_g=0.125\)). As presented in Table 1, the ResNet models for CUB-200 have a large number of inference flops; they are not scaled down, but their convolutions are replaced by depthwise separable convolutions.

The final loss in Eq. 7 requires weights that mediate the contribution of each individual term. The best loss weights we found are \(\alpha _1=1\) for the final classification, \(\alpha _2=1.5\) for the gater, and \(\alpha _3=0.05\) for the individual branches.

The numerical results on CIFAR-100, CUB-200, and Food-101, in terms of overall accuracy, size, and inference flops for different versions of DynK-Hydra, compared to the state-of-the-art HydraRes [5] and ResNet [10], are presented in Table 1; the results for ImageNet32 are presented in Table 3. One should compare the results either by looking at similar accuracies and the corresponding inference costs (flops), or by looking at similar inference costs (flops) and comparing the obtained accuracy among the three approaches.

On the Food-101 and ImageNet32 data sets, the results are reported on images of size 32 \(\times \) 32, hence the lower accuracy. Regardless of the accuracy level, our purpose is to prove the efficacy of the dynamic selection and specialization enforcement (DynK-Hydra) by comparing to the static selection mechanism of HydraRes [5] and to ResNet [10].

The DenseNet architecture, with the corresponding Hydra-Dense and DynK-Hydra versions, was trained on the CIFAR-100 data set for 120 epochs. Since [5] does not specify how the DenseNet version is constructed, we created a more compact DenseNet model adapted to the CIFAR-100 data set: 3 blocks with dense connectivity sizes of 3, 6, and 12 and a growth rate of 3. For Hydra-Dense and DynK-Hydra, each block is multiplied d times. The stem includes the first 2 blocks, and the third block is duplicated for each branch and for the gater. The numerical results are presented in Table 2.

Table 1 Comparison of our DynK-Hydra with HydraRes [5] and ResNet [10] on CIFAR-100, Food-101, and CUB-200 (best viewed in color)
Table 2 Comparison of our DynK-Hydra with Hydra-Dense [5] and DenseNet [29] on CIFAR-100

The number of inference flops for DynK-Hydra is calculated as the sum \(Gater_{flops}\) + \(Stem_{flops}\) + \(\lambda \) \(Branches_{flops}\), where \(\lambda \) is the mean number of activated branches. This number is usually less than the hyper-parameter K, as opposed to [5], where K branches are always activated. For example, on the CIFAR-100 data set, we perform experiments with \(n_b=20\), \(K=4\), and \(d\in \{1, ..., 7, 9\}\), measure the actual number of branches used on the test set, and obtain the following average \(\lambda \) values: {3.46, 2.68, 2.34, 2.24, 2.13, 1.98, 1.90, 1.75}. For Food-101 (\(K=4\)) the average \(\lambda \) values are {3.05, 2.62, 2.49, 2.43, 2.39, 2.38, 2.42, 2.30}, for CUB-200 (\(K=4\)) they are {2.01, 1.95, 1.99, 1.96, 1.95, 1.78, 1.76, 1.78}, and for ImageNet32 (\(K=5\)) they are {4.50, 4.16, 4.11, 4.07, 4.02}. The DenseNet-based (\(K=4\)) DynK-Hydra model is subject to the same optimization, with average \(\lambda \) values of {3.14, 2.69, 2.26, 1.96}. One can notice that larger models activate fewer branches, as opposed to HydraRes, where the number of activated branches is always \(K=4\) (or \(K=5\) for ImageNet32). To conclude, the gater selects fewer branches as its performance increases.
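The accounting above corresponds to the following simple helper; the flop counts for the stem, gater, and a single branch are assumed to be measured separately, and the numbers in the comment are examples taken from the averages above.

```python
def dynk_hydra_flops(stem_flops, gater_flops, branch_flops, lam):
    """Expected inference cost: stem + gater + lam * (flops of one branch),
    where lam is the average number of branches actually activated."""
    return stem_flops + gater_flops + lam * branch_flops

# HydraRes always uses lam = K (e.g., 4), whereas DynK-Hydra uses the measured
# averages (e.g., lam = 1.75 for the deepest CIFAR-100 model).
```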

Table 3 Comparison of HydraRes and DynK-Hydra on ImageNet32 data set, in terms of both accuracy and inference flops

Discussion

According to Table 1 (on CIFAR-100, see the blue highlights), at around 74% accuracy the proposed DynK-Hydra needs 139M flops, while HydraRes [5] needs 378M flops (a 2.7 times improvement for DynK-Hydra). ResNet does not reach this accuracy level even with its largest architecture (73.56%), which requires 767M flops (approximately 5.5 times less efficient at inference). Similar trends are observed on the Food-101 data set (red highlights), CUB-200 (yellow highlights), ImageNet32, and the DenseNet-like architectures.

Entries in Table 1 are compared by selecting a column (data set) and either choosing a base accuracy that can be roughly identified in all three types of models (in the same column) and comparing the corresponding computational cost, or vice versa. In our blue example, the base accuracy is chosen at roughly 74% (values in boxes in the blue accuracy column) and the corresponding computation costs are compared to infer the best performing method. The same is illustrated in the red and yellow columns for the other two data sets. The accuracy versus the number of flops is illustrated in Fig. 4.

Fig. 4

Accuracy per computing units (flops) for CIFAR-100. Comparison of our DynK-Hydra to the state-of-the-art ResNet [10] (left) and HydraRes [5] (right)

Table 4 Accuracy for multiple architecture depths on the CIFAR-100 data set
Fig. 5

Loss values on a synthetically created model obtained as a weighted average between a random pair of branches, using different interpolation factors. Example computed on CIFAR-100

Table 5 Accuracy of DynK-Hydra-d2 on CIFAR-100 using different subtask partitioning strategies
Table 6 Trade-off between model accuracy and inference flops, obtained by varying the threshold \(\tau \) applied on the gater prediction, together with the average number of activated branches \(\lambda \), on the Food-101 data set

To show that DynK-Hydra branches are located in different local minima of the loss landscape and specialize on different sub-tasks, we performed an experiment similar to [31] on the CIFAR-100 data set. To show that two branches (\(br_{task_A}\) and \(br_{task_B}\)) are not situated in the same local minimum, we create a new model and initialize its weights with a linear interpolation (with a factor \(\alpha \in [0,1]\)) of the parameters of \(br_{task_A}\) and \(br_{task_B}\). When \(\alpha \) is 0, \(br_{\alpha }\) = \(br_{task_A}\); when it is 1, \(br_{\alpha }\) = \(br_{task_B}\). Next, we test this new model on a data set comprising all the images assigned to \(task_A\) and \(task_B\) (namely, \(data set_{AB}\)). \(Task_A\) and \(task_B\) are the two partitions, out of the \(n_b=20\), assigned to the selected branches. If the two branches were located in the same local minimum, the loss of the new model \(br_{\alpha }\) should not be adversely impacted.
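The interpolation itself can be sketched as below; parameter handling is illustrative, and integer buffers (e.g., batch-normalization counters) would need special care in practice.

```python
import copy

def interpolate_branches(branch_a, branch_b, alpha):
    """Return a branch whose weights are (1 - alpha) * theta_A + alpha * theta_B."""
    blended = copy.deepcopy(branch_a)
    state_a, state_b = branch_a.state_dict(), branch_b.state_dict()
    blended.load_state_dict({k: (1 - alpha) * state_a[k] + alpha * state_b[k] for k in state_a})
    return blended

# Sweeping alpha in [0, 1] and evaluating the blended branch on dataset_AB traces
# the loss "hill" between the two local minima shown in Fig. 5.
```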

Our results show a low value of the loss when \(\alpha \) is close to 0 or 1, and a significantly greater loss everywhere else. This means that any two branches have a “hill” between them in the loss landscape. We illustrate these findings in Fig. 5.

In addition, we performed the same analysis on the accuracy of the synthetically created model \(br_{\alpha }\). We noticed that the accuracies at the extreme values of \(\alpha \) (close to 0 or close to 1) almost reach 50%, meaning that each branch performs well only on its part of the data set.

Ablation studies

The two main contributions of this work are (1) the two-stage selection of the branches involved in the final prediction, and (2) the individual branch specialization loss \(L_{branch}\) described in Sect. 3.3.

Table 4 illustrates the improvements induced by each of the proposed techniques over the vanilla HydraRes [5] architecture. In this table, each line reports the accuracy of a different model. HydraRes [5] is the baseline architecture; DynK refers to a model in which we only apply the dynamic selection of the branches involved in the final prediction; HydraRes-SP is a HydraRes [5] architecture to which we apply the individual branch specialization loss \(L_{branch}\); and DynK-Hydra is the model trained with both proposed strategies. The columns of the table refer to the depth of the model. It can be noticed that, using either of the proposed strategies, we always obtain an improvement over the original HydraRes [5] architecture. Moreover, the DynK and DynK-Hydra models also bring up to a two-times reduction in the inference time.

As we explain in Sect. 3.1, we partition the input space into coarse labels in order to assign each branch a (balanced) subset of classes. Similar to [5], we call this strategy CLUSTER-Balance, and we compare it with a random balanced splitting and with the human semantic coarse labeling (as provided by the CIFAR-100 data set). We present the benefits of CLUSTER-Balance in Table 5.

Table 6 shows the impact of the threshold \(\tau \) applied on the gater predictions; lower values of \(\tau \) activate more branches (i.e., use them in the ensemble’s final prediction), hence a higher accuracy, but also a higher computational cost.

Conclusions and future work

DynK-Hydra targets reducing the inference time of classification tasks with a medium to large number of classes while preserving the overall accuracy. We show inference time improvements on the order of 2 to 5.5 times with superior classification accuracy compared to baseline ResNet networks, and a 2.8 times improvement in inference time with a marginal accuracy improvement of 1.2% against HydraRes [5]. We apply dynamic sparse execution, as opposed to HydraRes, and show that a significant reduction of inference time is still possible. Although DynK-Hydra enables more work-efficient inference than fully static or semi-static architectures, this comes at the cost of longer training times and a slightly more complex training process. The projective transform performed by a pretrained state-of-the-art convolutional neural network (EfficientNet-B7 in our experiments) allows automatic clustering of the data set and automatic generation of the class clusters used for training the architecture branches. As this process only affects the training phase, we consider that it justifies the inference time gains. Exploring dynamic training with small batches, where each sample may execute different branches, is a challenging task that can be approached in future research.