1. Introduction
Stroke remains a leading cause of long-term disability in adults worldwide, with a significant proportion of survivors experiencing impaired balance function, which severely compromises their quality of life and independence [1]. Personalized rehabilitation prognosis assessment and treatment path selection are crucial for optimizing patient recovery outcomes. Traditional rehabilitation models often rely on clinicians’ empirical judgment and are constrained by limited resources, making it challenging to meet the individualized needs of a large patient population [2]. This challenge is further compounded by the increasing adoption of TeleRehabilitation (TR), raising a critical clinical question: how can we accurately predict a patient’s response to TR versus conventional in-person rehabilitation (CR) and provide optimal treatment allocation recommendations accordingly?
Figure 1.
An overview of the motivation, challenges, and conceptual differences in stroke rehabilitation. The left panel highlights stroke-related disabilities, emphasizing balance problems and the importance of recovery for quality of life. The middle panel illustrates the variability in post-stroke recovery outcomes and the limitations of traditional rehabilitation models, leading to a critical need for personalized approaches. The right panel contrasts Conventional Rehabilitation (CR) with Telerehabilitation (TR), outlining their respective operational models, benefits, and potential outcomes.
While existing research has made progress in stroke prognosis prediction, most studies are limited to single-modality data, such as clinical tabular information, or fail to fully leverage the dynamic time-series data captured during actual training sessions [? ]. Furthermore, personalized prediction of patient responses under different treatment modalities (TR vs. CR), and treatment recommendation based on such predictions, remain underdeveloped. The absence of a robust framework that integrates diverse multimodal and temporal patient data to simultaneously predict functional recovery and recommend optimal treatment strategies represents a significant gap in current stroke rehabilitation research. This study focuses on predicting the improvement in Berg Balance Scale (BBS) scores after 8 weeks of rehabilitation for stroke patients and assessing their likelihood of responding to either TR or CR, ultimately providing personalized treatment allocation advice.
To address these limitations, we propose Causal-MMFNet (Causal Multimodal Fusion Network), a novel deep learning framework that leverages multimodal time-series data to predict balance function recovery in stroke patients and to enable individualized treatment recommendations. Causal-MMFNet introduces two core innovations: a dynamic cross-modal attention fusion mechanism and an Individual Treatment Effect (ITE) estimation module. The dynamic attention mechanism adaptively weights and fuses features from the different modalities, generating a richer, context-aware patient representation. Crucially, the ITE estimation module explicitly models counterfactual outcomes for each patient under both TR and CR, allowing us to quantify the differential benefit of each treatment and recommend the option with the higher expected gain. The framework also incorporates modality-specific encoders for IMU, video keypoint, training log, and clinical tabular data, combined with robust loss functions and causal consistency regularization to improve prediction accuracy and reliability. Finally, we employ Monte Carlo Dropout to estimate predictive uncertainty, providing clinicians with a measure of confidence for personalized decision-making.
Our experimental evaluation was conducted on the StrokeBalance-Sim dataset, a simulated yet comprehensive dataset (n=1,216) that integrates clinical tabular data (36 dimensions), wearable IMU time-series data, home-based training logs, and Kinect/mobile video-based keypoint time-series data. We define two primary tasks: Task A (regression) predicts the 8-week BBS score improvement (ΔBBS), while Task B (classification/recommendation) identifies "responders" (ΔBBS ≥ 5 points) and recommends the optimal treatment (TR or CR) based on the estimated response probabilities. Performance was assessed using Mean Absolute Error (MAE), Root Mean Squared Error (RMSE), and the coefficient of determination (R²) for regression, and Area Under the Receiver Operating Characteristic Curve (AUROC), Area Under the Precision-Recall Curve (AUPRC), and Expected Calibration Error (ECE) for classification. Causal-MMFNet achieved the best performance on every evaluation metric, outperforming the baseline methods and the state-of-the-art multimodal time-series learning framework MM-TRNet, supporting its efficacy for both prediction and personalized treatment recommendation.
Our key contributions are summarized as follows:
We introduce Causal-MMFNet, a novel end-to-end deep learning framework that effectively integrates diverse multimodal and temporal patient data for simultaneous balance function recovery prediction and personalized treatment allocation in stroke rehabilitation.
We propose a sophisticated dynamic cross-modal attention fusion mechanism and an innovative Individual Treatment Effect (ITE) estimation module, coupled with causal consistency regularization, enabling robust counterfactual outcome prediction and evidence-based treatment recommendations.
We demonstrate that Causal-MMFNet achieves state-of-the-art performance on the StrokeBalance-Sim dataset, offering superior accuracy and reliability in predicting rehabilitation outcomes and guiding personalized treatment decisions compared to existing methods.
3. Method
We introduce Causal-MMFNet (Causal Multimodal Fusion Network), a novel deep learning framework specifically designed to integrate diverse multimodal time-series data for predicting balance function recovery in stroke patients and providing individualized treatment recommendations. The core innovations of Causal-MMFNet lie in its dynamic cross-modal attention fusion mechanism and its Individual Treatment Effect (ITE) estimation module, which together enable a more precise understanding of complex patient states and robust prediction of potential outcomes under different rehabilitation strategies. These innovations are crucial for moving beyond population-level averages toward truly personalized medicine in rehabilitation. An overview of the Causal-MMFNet architecture is presented in Figure 2.
3.1. Multimodal Encoders
Causal-MMFNet begins by processing each modality of patient data through a specialized encoder tailored to capture salient features from that data type. Let $X_m$ denote the input data for modality $m \in \{\mathrm{imu}, \mathrm{pose}, \mathrm{log}, \mathrm{clin}\}$. Each encoder $E_m$ transforms $X_m$ into a high-dimensional semantic embedding $h_m = E_m(X_m) \in \mathbb{R}^{D}$, where $D$ is the embedding dimension. These embeddings serve as context-rich representations for subsequent fusion.
3.1.1. IMU Time-Series Encoder
For wearable Inertial Measurement Unit (IMU) time-series data, which captures fine-grained movement patterns and kinematic information, we employ a 1D convolutional layer followed by a residual Temporal Convolutional Network (TCN). The initial 1D convolutional layer acts as a feature extractor and temporal dimensionality-reduction step, transforming raw sensor signals into a more abstract representation. The TCN, built from dilated convolutions and residual connections, models both local and long-range temporal dependencies within the IMU signals: because the dilation factor typically doubles from layer to layer, the receptive field grows exponentially with depth while the number of parameters grows only linearly, making the TCN well suited to capturing complex movement dynamics over long sequences.
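$$h_{\mathrm{imu}} = \mathrm{TCN}\big(\mathrm{Conv1D}(X_{\mathrm{imu}})\big) \in \mathbb{R}^{D},$$
where $X_{\mathrm{imu}} \in \mathbb{R}^{T_{\mathrm{imu}} \times C_{\mathrm{imu}}}$ represents the raw IMU sensor data streams, with $T_{\mathrm{imu}}$ being the number of time steps and $C_{\mathrm{imu}}$ the number of sensor channels.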
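To make the receptive-field argument concrete, the following minimal pure-Python sketch shows a single-channel dilated causal convolution and the receptive field of a stack of such layers. It is illustrative only: the kernel values and the doubling dilation schedule (1, 2, 4, 8) are assumptions, and a real TCN would use a deep learning library with multiple channels, learned weights, and typically two convolutions per residual block (which enlarges the receptive field further).

```python
def dilated_causal_conv1d(x, kernel, dilation):
    """Single-channel causal convolution: y[t] = sum_k kernel[k] * x[t - k * dilation].

    Positions before the start of the sequence are treated as zeros (left padding),
    so y[t] never depends on future samples.
    """
    y = []
    for t in range(len(x)):
        acc = 0.0
        for k, w in enumerate(kernel):
            idx = t - k * dilation
            if idx >= 0:
                acc += w * x[idx]
        y.append(acc)
    return y


def receptive_field(kernel_size, dilations):
    """Receptive field of stacked causal convolutions (one convolution per dilation)."""
    return 1 + sum((kernel_size - 1) * d for d in dilations)


# An impulse at t = 0 reappears 2 steps later with dilation 2.
print(dilated_causal_conv1d([1.0, 0.0, 0.0, 0.0, 0.0], [1.0, 1.0], 2))
# Four layers with kernel size 3 and doubling dilations cover 31 time steps.
print(receptive_field(3, [1, 2, 4, 8]))
```

Adding one more layer with dilation 16 would extend the receptive field by 32 steps at the cost of a single extra kernel, which is the trade-off that makes TCNs attractive for long IMU recordings.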
3.1.2. Video Keypoint Time-Series Encoder
Video-based keypoint time-series data, reflecting posture changes, limb coordination, and gait parameters, is processed using a multi-layer Transformer network. The self-attention mechanism of the Transformer is well suited to capturing spatial relationships between different body keypoints at each time step, as well as long-range temporal dependencies across the keypoint sequence. This allows the encoder to learn how specific movements evolve and interact over time, forming a comprehensive kinematic signature.
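$$h_{\mathrm{pose}} = \mathrm{Transformer}(X_{\mathrm{pose}}) \in \mathbb{R}^{D},$$
where $X_{\mathrm{pose}} \in \mathbb{R}^{T_{\mathrm{pose}} \times K \times d_{\mathrm{kp}}}$ is the sequence of 2D or 3D keypoint coordinates over time, with $K$ being the number of detected body keypoints and $d_{\mathrm{kp}} \in \{2, 3\}$ the coordinate dimension.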
3.1.3. Training Log Data Encoder
Patient training log data, which provides insights into adherence, effort, and progressive overload, is encoded using a Gated Recurrent Unit (GRU) network. GRUs are a gated variant of Recurrent Neural Networks (RNNs) designed to mitigate the vanishing gradient problem, which makes them capable of learning long-term dependencies across potentially sparse sequences of log entries. This enables the GRU to generate abstract features indicative of patient engagement, progress, and consistency in their rehabilitation exercises.
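$$h_{\mathrm{log}} = \mathrm{GRU}(X_{\mathrm{log}}) \in \mathbb{R}^{D},$$
where $X_{\mathrm{log}} \in \mathbb{R}^{T_{\mathrm{log}} \times F_{\mathrm{log}}}$ represents the time series of training activity logs, with $T_{\mathrm{log}}$ being the number of log entries and $F_{\mathrm{log}}$ the feature dimension of each entry.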
3.1.4. Clinical Tabular Data Encoder
Baseline clinical tabular data, encompassing static patient demographics, medical history, initial clinical assessments, and other fixed attributes, is processed through a two-layer Multilayer Perceptron (MLP) with Layer Normalization. Layer Normalization standardizes feature distributions, improving training stability and convergence for tabular inputs. The MLP then applies non-linear transformations to these features, generating a high-dimensional semantic representation that encodes the patient’s initial clinical state and static characteristics.
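$$h_{\mathrm{clin}} = \mathrm{MLP}(X_{\mathrm{clin}}) \in \mathbb{R}^{D},$$
where $X_{\mathrm{clin}} \in \mathbb{R}^{36}$ is the 36-dimensional vector of static clinical features.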
3.2. Dynamic Cross-Modal Attention Fusion
Unlike simpler fusion strategies such as concatenation or fixed-weight averaging, Causal-MMFNet incorporates a dynamic cross-modal attention fusion mechanism. This module adaptively weighs and combines the semantic embeddings $h_m$ from each modality encoder. The attention mechanism calculates modality-specific scalar weights based on the context provided by each modality’s embedding, allowing the model to focus dynamically on the most informative modalities for a given patient’s state and the specific prediction task. This matters because the relevance of different data types may vary across patients or stages of recovery. Let $h_m \in \mathbb{R}^{D}$ be the embedding for modality $m$. The attention score $s_m$ for each modality is first computed by passing its embedding through a dedicated single-layer Multilayer Perceptron followed by a non-linear activation $\phi$ (e.g., ReLU). These scores are then normalized across all modalities using a softmax function to obtain the attention weights $\alpha_m$:
$$s_m = \phi\big(\mathbf{w}_m^{\top} h_m + b_m\big), \qquad \alpha_m = \frac{\exp(s_m)}{\sum_{m'} \exp(s_{m'})}.$$
The global patient representation $z \in \mathbb{R}^{D}$ is then generated by a weighted sum of the modality embeddings, where each embedding $h_m$ is scaled by its corresponding attention weight $\alpha_m$:
$$z = \sum_{m} \alpha_m \cdot h_m,$$
where $\cdot$ denotes multiplication of the vector $h_m$ by the scalar weight $\alpha_m$. This dynamic fusion generates a rich, context-aware representation that captures the interplay between different data sources, adapting to their varying predictive power.
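The fusion step can be sketched in a few lines of plain Python (no deep learning framework), assuming ReLU as the score activation and per-modality scoring weights chosen by hand for illustration; in the actual model these would be learned parameters, and the modality names and embedding values below are hypothetical.

```python
import math


def relu_score(h, w, b):
    """Scalar attention score: ReLU(w . h + b) for one modality embedding h."""
    return max(0.0, sum(wi * hi for wi, hi in zip(w, h)) + b)


def softmax(xs):
    """Numerically stable softmax over a list of scores."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def attention_fuse(embeddings, params):
    """Fuse per-modality embeddings into z = sum_m alpha_m * h_m.

    embeddings: dict modality -> list[float] of length D
    params: dict modality -> (w, b), the per-modality scoring weights
    """
    names = list(embeddings)
    scores = [relu_score(embeddings[n], *params[n]) for n in names]
    alphas = softmax(scores)
    dim = len(embeddings[names[0]])
    z = [0.0] * dim
    for alpha, name in zip(alphas, names):
        for i, value in enumerate(embeddings[name]):
            z[i] += alpha * value  # scalar weight times each embedding component
    return z, dict(zip(names, alphas))


# Toy 3-dimensional embeddings (illustrative values only).
emb = {
    "imu": [0.2, 0.1, 0.0],
    "pose": [0.5, 0.4, 0.3],
    "log": [0.0, 0.1, 0.2],
    "clin": [1.0, 0.9, 0.8],
}
par = {name: ([1.0, 1.0, 1.0], 0.0) for name in emb}
z, alphas = attention_fuse(emb, par)
print(alphas)
```

Because the weights pass through a softmax, they always sum to one; with all-zero scoring weights every modality gets the same weight and the fusion reduces to a plain average, which is the fixed-weight baseline the dynamic mechanism is meant to improve on.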
3.3. Individual Treatment Effect (ITE) Estimation Module
The Individual Treatment Effect (ITE) estimation module is a cornerstone of Causal-MMFNet, designed to explicitly estimate counterfactual outcomes for each patient under each treatment scenario, TeleRehabilitation (TR) or Conventional Rehabilitation (CR); this capability is the basis for personalized treatment recommendations. The module takes the fused global patient representation $z$ as input and branches into two distinct prediction heads, one for TR and one for CR, each implemented as a multilayer perceptron. Each head $f_t$, where $t \in \{\mathrm{TR}, \mathrm{CR}\}$, simultaneously outputs two predictions for its respective treatment:
1. The predicted 8-week Berg Balance Scale (BBS) score improvement, denoted $\hat{y}_t$. This is a regression output.
2. The predicted probability of being a "responder," defined as achieving $\Delta\mathrm{BBS} \geq 5$, denoted $\hat{p}_t$. This is a classification output.
The outputs of the ITE estimation module can be formulated as:
$$\big(\hat{y}_t, \hat{p}_t\big) = f_t(z), \qquad t \in \{\mathrm{TR}, \mathrm{CR}\}.$$
By modeling these two potential outcomes for each patient, we can calculate the individualized treatment effect (ITE), defined as the difference in expected BBS improvement between the two treatments:
$$\widehat{\mathrm{ITE}} = \hat{y}_{\mathrm{TR}} - \hat{y}_{\mathrm{CR}}.$$
Based on this estimated ITE, the framework recommends the treatment option predicted to yield the greater benefit for the individual patient, i.e., TR when $\widehat{\mathrm{ITE}} > 0$ and CR otherwise.
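As a hedged illustration of the two-head design and the decision rule, the sketch below replaces each MLP head with a single linear layer (a regression output plus a sigmoid for the responder probability) and takes the ITE as the TR prediction minus the CR prediction. The parameter values are hypothetical; only the structure mirrors the description above.

```python
import math


def linear(z, w, b):
    """Affine map w . z + b."""
    return sum(wi * zi for wi, zi in zip(w, z)) + b


def head(z, reg_params, cls_params):
    """One treatment head returning (y_hat, p_hat).

    A single linear layer stands in for the MLP head (a simplifying assumption).
    """
    y_hat = linear(z, *reg_params)  # predicted ΔBBS
    p_hat = 1.0 / (1.0 + math.exp(-linear(z, *cls_params)))  # responder probability
    return y_hat, p_hat


def recommend(z, heads):
    """heads: dict "TR"/"CR" -> (reg_params, cls_params)."""
    y_tr, p_tr = head(z, *heads["TR"])
    y_cr, p_cr = head(z, *heads["CR"])
    ite = y_tr - y_cr  # predicted ΔBBS under TR minus under CR
    choice = "TR" if ite > 0 else "CR"  # ties default to CR (an arbitrary choice)
    return {"recommend": choice, "ite": ite, "p_tr": p_tr, "p_cr": p_cr}


# Hypothetical head parameters for a 2-dimensional representation z.
heads = {
    "TR": (([2.0, 0.0], 1.0), ([0.0, 0.0], 0.0)),
    "CR": (([0.0, 1.0], 3.0), ([0.0, 0.0], 0.0)),
}
print(recommend([2.0, 1.0], heads))
```

Both heads read the same representation $z$, so a single forward pass yields both potential outcomes; this is what allows the counterfactual comparison even though only one treatment is observed per patient.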
3.4. Loss Function and Causal Consistency Regularization
Causal-MMFNet is optimized using a comprehensive loss function that addresses both the regression task of predicting BBS improvement and the classification task of identifying responders, augmented by a novel causal consistency regularization term. This multi-objective approach ensures robust and causally-aware learning.
3.4.1. Regression Task Loss
For the ΔBBS prediction regression task, we employ a combination of L1 loss and Smooth L1 loss. L1 loss (Mean Absolute Error) penalizes the absolute difference between predictions and true values and is robust to outliers. Smooth L1 loss, closely related to the Huber loss (the two coincide when the transition threshold is 1), is quadratic for small errors and linear for large ones; it is therefore less sensitive to outliers than L2 loss while providing a smoother gradient near the target than L1 loss, which benefits training stability. The regression loss is computed only for the treatment arm $c$ that the patient actually received:
$$\mathcal{L}_{\mathrm{reg}} = \lambda_1 \big|\hat{y}_c - y_c\big| + \lambda_2\, \mathrm{SmoothL1}\big(\hat{y}_c - y_c\big),$$
where $y_c$ is the observed BBS improvement for the assigned treatment $c$, and $\lambda_1, \lambda_2$ are hyperparameters weighting the contribution of each loss component.
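A minimal sketch of the arm-specific regression loss follows, assuming the PyTorch-style Smooth L1 definition with threshold beta = 1 and equal weights of 1 for both components (the tuned weights are not stated). The key point is that only the prediction for the arm the patient actually received is supervised.

```python
def smooth_l1(r, beta=1.0):
    """Smooth L1 on a residual r: quadratic below beta, linear above."""
    a = abs(r)
    return 0.5 * a * a / beta if a < beta else a - 0.5 * beta


def regression_loss(preds, received, targets, lam1=1.0, lam2=1.0):
    """Mean of lam1 * |r| + lam2 * SmoothL1(r) over patients, using only the received arm.

    preds: list of dicts {"TR": y_hat_tr, "CR": y_hat_cr}
    received: list of "TR"/"CR", the treatment each patient actually received
    targets: observed ΔBBS for the received treatment
    """
    total = 0.0
    for pred, arm, y in zip(preds, received, targets):
        r = pred[arm] - y  # the counterfactual arm's prediction gets no gradient here
        total += lam1 * abs(r) + lam2 * smooth_l1(r)
    return total / len(targets)


# The CR prediction (100.0) is ignored because this patient received TR.
print(regression_loss([{"TR": 6.0, "CR": 100.0}], ["TR"], [5.5]))
```

Since the unobserved arm is never supervised directly, the two heads can drift apart on regions of the covariate space where one treatment is rare, which is the gap the causal consistency term below is meant to address.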
3.4.2. Classification Task Loss
For the responder classification task, we use Focal Loss to address the class imbalance common in clinical datasets, where responders may be considerably fewer than non-responders. Focal Loss down-weights easy examples and focuses training on hard, misclassified examples. In addition, an Expected Calibration Error (ECE)-style calibration term is included to encourage the predicted probabilities $\hat{p}_c$ to be well calibrated, meaning that they accurately reflect the true likelihood of response. As with the regression loss, the classification loss is applied only to the observed treatment arm $c$:
$$\mathcal{L}_{\mathrm{cls}} = \mathcal{L}_{\mathrm{focal}}\big(\hat{p}_c, r_c\big) + \mathcal{L}_{\mathrm{cal}}\big(\hat{p}_c, r_c\big),$$
where $r_c \in \{0, 1\}$ is the true binary responder status for the assigned treatment $c$.
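The classification terms can be sketched as follows, assuming the standard binary focal loss with focusing parameter gamma = 2 (the class-balancing alpha factor is omitted) and, for calibration, a binned ECE that compares the mean predicted responder probability with the observed responder rate in each of 10 equal-width bins. Binned ECE is not differentiable, so it is shown here as the evaluation metric; the exact differentiable "ECE-style" training term used by Causal-MMFNet is not specified in the text.

```python
import math


def focal_loss(p, r, gamma=2.0, eps=1e-7):
    """Binary focal loss for one example: -(1 - p_t)^gamma * log(p_t)."""
    p = min(max(p, eps), 1.0 - eps)  # clip to avoid log(0)
    p_t = p if r == 1 else 1.0 - p  # probability assigned to the true class
    return -((1.0 - p_t) ** gamma) * math.log(p_t)


def expected_calibration_error(probs, labels, n_bins=10):
    """Binned ECE: bin-size-weighted |mean predicted prob - observed positive rate|."""
    n = len(probs)
    ece = 0.0
    for b in range(n_bins):
        lo, hi = b / n_bins, (b + 1) / n_bins
        # the last bin is closed on the right so that p == 1.0 is counted
        idx = [i for i, p in enumerate(probs)
               if lo <= p < hi or (b == n_bins - 1 and p == hi)]
        if not idx:
            continue
        mean_prob = sum(probs[i] for i in idx) / len(idx)
        pos_rate = sum(labels[i] for i in idx) / len(idx)
        ece += len(idx) / n * abs(mean_prob - pos_rate)
    return ece


# An easy, correct example contributes far less than a hard one.
print(focal_loss(0.9, 1), focal_loss(0.6, 1))
# Predicting 0.25 for a group with a 25% responder rate is perfectly calibrated.
print(expected_calibration_error([0.25] * 4, [1, 0, 0, 0]))
```

With gamma = 0 the focal loss reduces to ordinary binary cross-entropy, which is a convenient sanity check when tuning the focusing parameter.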
3.4.3. Causal Consistency Regularization
To enhance the robustness of the ITE estimation module and encourage it to learn more reliable causal relationships, we introduce a causal consistency regularization term $\mathcal{L}_{\mathrm{cc}}$. This term operates on a single patient’s representation $z$ and encourages the two treatment heads, $f_{\mathrm{TR}}$ and $f_{\mathrm{CR}}$, to exhibit consistent functional behavior. Specifically, it enforces structural similarities or relationships between the counterfactual predictions generated by $f_{\mathrm{TR}}$ and $f_{\mathrm{CR}}$, thereby mitigating potential biases that might arise from covariate shift or distributional differences between the observed treatment groups. The precise form of $\mathcal{L}_{\mathrm{cc}}$ is designed to penalize inconsistencies in how the model maps $z$ to the outcomes under different treatments.
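The total loss function for Causal-MMFNet is a weighted sum of these components:
$$\mathcal{L} = \mathcal{L}_{\mathrm{reg}} + \mathcal{L}_{\mathrm{cls}} + \lambda_{\mathrm{cc}}\, \mathcal{L}_{\mathrm{cc}},$$
where $\lambda_{\mathrm{cc}}$ is a weighting hyperparameter for the causal consistency term, controlling its influence on the overall optimization process.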
3.5. Uncertainty Estimation
To give clinicians an additional layer of confidence for personalized decision-making, Causal-MMFNet incorporates uncertainty estimation via Monte Carlo Dropout (MC Dropout) at inference time. By keeping the dropout layers active during testing and performing 20 stochastic forward passes for each patient, we obtain an ensemble of predictions for both $\hat{y}_t$ and $\hat{p}_t$. The variance or interquartile range across these 20 predictions serves as an estimate of the model’s predictive uncertainty. This indicates how confident the model is in its recommendations, enabling clinicians to assess the reliability of personalized treatment effect estimates before acting on them.
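The MC Dropout procedure can be illustrated with a toy linear predictor whose inputs are randomly dropped on every forward pass. This is a minimal sketch: the dropout rate of 0.2, the model weights, and the seed are hypothetical, and inverted dropout (rescaling kept units by 1 / (1 - rate)) is assumed, as in common deep learning libraries. The spread of the 20 sampled predictions is summarized by their standard deviation.

```python
import random
import statistics


def dropout_forward(x, w, b, rate, rng):
    """One stochastic pass of a toy linear model with inverted dropout on its inputs."""
    keep = 1.0 - rate
    out = b
    for xi, wi in zip(x, w):
        if rng.random() < keep:  # keep this input with probability 1 - rate
            out += wi * xi / keep  # rescale so the expected output matches the no-dropout model
    return out


def mc_dropout_predict(x, w, b, rate=0.2, n_passes=20, seed=0):
    """Run n_passes stochastic passes; return mean, standard deviation, and raw samples."""
    if not 0.0 <= rate < 1.0:
        raise ValueError("rate must be in [0, 1)")
    rng = random.Random(seed)
    samples = [dropout_forward(x, w, b, rate, rng) for _ in range(n_passes)]
    return statistics.mean(samples), statistics.stdev(samples), samples


mean, sd, samples = mc_dropout_predict([1.0, 2.0, 0.5], [0.8, 0.3, -0.4], b=1.5)
print(f"prediction {mean:.2f} +/- {sd:.2f} over {len(samples)} passes")
```

Setting the rate to 0 collapses all passes to the same deterministic output (zero spread), which is a quick way to confirm that the reported uncertainty comes only from the dropout noise.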
4. Experiments
4.1. Dataset and Task Definition
The effectiveness of Causal-MMFNet was evaluated on the StrokeBalance-Sim dataset, a comprehensive simulated dataset comprising 1,216 stroke patients. This dataset integrates diverse multimodal patient data, including 36-dimensional clinical tabular data, wearable IMU time-series data, home-based physical training logs, and Kinect/mobile video-based keypoint time-series data. The dataset was partitioned at the subject level into training, validation, and test sets with a ratio of 70%, 10%, and 20%, respectively, ensuring no data leakage between splits. The primary endpoint for recovery assessment was the Berg Balance Scale (BBS) score measured after 8 weeks of rehabilitation. For classification purposes, a patient was defined as a "responder" if their BBS score improvement (ΔBBS) at the 8-week follow-up was 5 points or more. Our experimental evaluation focused on two primary tasks: (1) Task A (Regression), which involves predicting the 8-week ΔBBS given the patient’s baseline clinical information and time-series data from the first two weeks of rehabilitation training; and (2) Task B (Classification and Recommendation), which aims to determine whether a patient is a responder (ΔBBS ≥ 5 points) and to estimate their probability of response under both TeleRehabilitation (TR) and Conventional Rehabilitation (CR). Based on these estimated probabilities and the individualized treatment effect, the model provides a personalized treatment allocation recommendation.
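A subject-level split and the responder label can be sketched as below. The shuffling seed, the rounding rule, and the helper names are assumptions; the sketch only illustrates that partitioning is done over unique subject IDs, so no subject’s records appear in more than one split.

```python
import random


def subject_level_split(subject_ids, ratios=(0.7, 0.1, 0.2), seed=42):
    """Split unique subject IDs into train/val/test sets; all records of a subject stay together."""
    ids = sorted(set(subject_ids))
    random.Random(seed).shuffle(ids)
    n = len(ids)
    n_train = round(ratios[0] * n)
    n_val = round(ratios[1] * n)
    train = set(ids[:n_train])
    val = set(ids[n_train:n_train + n_val])
    test = set(ids[n_train + n_val:])  # the remainder goes to the test set
    return train, val, test


def is_responder(bbs_baseline, bbs_week8, threshold=5):
    """1 if the 8-week BBS improvement (ΔBBS) is at least `threshold` points, else 0."""
    return int(bbs_week8 - bbs_baseline >= threshold)


# 1,216 subjects, each contributing several records (e.g., sessions).
records = [sid for sid in range(1216) for _ in range(3)]
train, val, test = subject_level_split(records)
print(len(train), len(val), len(test))
```

Splitting over subject IDs rather than over records matters here because each patient contributes many correlated time-series windows; a record-level split would let the model see the same patient in both training and test data.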
4.2. Evaluation Metrics
To comprehensively assess model performance across both tasks, we employ a suite of standard metrics. For Task A (Regression), we report the Mean Absolute Error (MAE), Root Mean Squared Error (RMSE), and the coefficient of determination (R²). Lower MAE and RMSE values indicate better predictive accuracy, while a higher R² signifies that a greater proportion of the variance in the outcome is explained by the model. For Task B (Classification), we use the Area Under the Receiver Operating Characteristic Curve (AUROC), the Area Under the Precision-Recall Curve (AUPRC), and the Expected Calibration Error (ECE). A higher AUROC indicates better discrimination between responders and non-responders, while a higher AUPRC indicates a better precision-recall trade-off for the responder class, which is especially informative for imbalanced datasets. A lower ECE indicates better calibration, meaning that predicted probabilities more closely reflect observed response rates. Finally, to evaluate the quality of treatment allocation, we assess the average actual ΔBBS among the Top-Q% of patients for whom the model recommended a specific treatment, reflecting the practical utility of the personalized recommendations. We report results for Q = 25%, i.e., the Top-25% of patients ranked by the model.
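For reference, the regression metrics and AUROC can be computed from first principles as below (a pure-Python sketch; in practice one would typically use a library such as scikit-learn). AUROC is computed here as the probability that a randomly chosen responder receives a higher score than a randomly chosen non-responder, counting ties as one half.

```python
import math


def mae(y, y_hat):
    """Mean Absolute Error."""
    return sum(abs(a - b) for a, b in zip(y, y_hat)) / len(y)


def rmse(y, y_hat):
    """Root Mean Squared Error."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(y, y_hat)) / len(y))


def r2(y, y_hat):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    mean = sum(y) / len(y)
    ss_res = sum((a - b) ** 2 for a, b in zip(y, y_hat))
    ss_tot = sum((a - mean) ** 2 for a in y)
    return 1.0 - ss_res / ss_tot


def auroc(labels, scores):
    """Pairwise (Mann-Whitney) AUROC; O(n_pos * n_neg), fine for small illustrations."""
    pos = [s for s, l in zip(scores, labels) if l == 1]
    neg = [s for s, l in zip(scores, labels) if l == 0]
    wins = 0.0
    for p in pos:
        for q in neg:
            wins += 1.0 if p > q else 0.5 if p == q else 0.0
    return wins / (len(pos) * len(neg))


print(mae([1, 2, 3], [1, 2, 5]), rmse([1, 2, 3], [1, 2, 5]))
print(auroc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))
```

The pairwise definition makes it clear why AUROC measures ranking only: any monotone rescaling of the scores leaves it unchanged, which is why calibration has to be assessed separately with ECE.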
4.3. Baseline Methods
To benchmark the performance of Causal-MMFNet, we compared it against a range of established machine learning and deep learning methods, representing different approaches to handling patient data for prediction tasks. These include: Linear Regression, a fundamental statistical model used as a basic benchmark; Random Forest, an ensemble learning method known for its robustness; and XGBoost, a highly efficient gradient boosting framework popular for tabular data. For time-series data, we compared against sequential models such as LSTM (Long Short-Term Memory), a type of recurrent neural network adept at learning from sequences; TCN (Temporal Convolutional Network), a convolutional architecture well-suited for sequence modeling; and the Transformer network, which leverages self-attention for capturing long-range dependencies. Finally, we included MM-TRNet, a state-of-the-art multimodal time-series learning framework designed for similar clinical prediction tasks, serving as a strong deep learning baseline that integrates various data types. For all baselines, multimodal inputs were processed either by concatenating flattened features (for non-sequential models) or by using specialized architectures (e.g., separate encoders for each modality, followed by a fusion layer, where applicable, for deep learning models) to allow for fair comparison within their architectural capabilities.
4.4. Implementation Details
The proposed Causal-MMFNet framework was implemented using the PyTorch deep learning library. Optimization was performed using the AdamW optimizer, configured with an initial learning rate of and a weight decay of . Models were trained with a batch size of 64 for 80 epochs, incorporating an early stopping strategy with a patience of 10 epochs based on performance on the validation set. To enhance model generalization and robustness, several data augmentation techniques were applied: IMU and video keypoint (pose) data underwent jittering and time warping, pose data was additionally augmented by simulating keypoint loss (randomly missing keypoints), and training log data was augmented by injecting noise. Hyperparameter tuning and model selection were conducted using 5-fold cross-validation on the training set, and the final performance of the chosen model was evaluated on the independent test set.
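The augmentations listed above can be sketched as simple sequence transforms. The noise level, warp factor, and keypoint drop probability below are hypothetical, and time warping is implemented here as uniform resampling by linear interpolation, which is one of several common variants (the exact variant used is not specified).

```python
import random


def jitter(seq, sigma=0.01, rng=None):
    """Add Gaussian noise to every value of a [T][C] sequence."""
    rng = rng if rng is not None else random.Random()
    return [[v + rng.gauss(0.0, sigma) for v in frame] for frame in seq]


def time_warp(seq, factor):
    """Uniformly resample a [T][C] sequence to round(T * factor) frames via linear interpolation."""
    T = len(seq)
    if T < 2:
        raise ValueError("need at least two frames")
    new_T = max(2, round(T * factor))
    out = []
    for j in range(new_T):
        pos = j * (T - 1) / (new_T - 1)  # fractional index into the original sequence
        i = int(pos)
        if i >= T - 1:
            out.append(list(seq[-1]))
            continue
        frac = pos - i
        out.append([a + frac * (b - a) for a, b in zip(seq[i], seq[i + 1])])
    return out


def drop_keypoints(frames, p_drop=0.1, rng=None):
    """Zero out each keypoint of a [T][K][d] pose sequence with probability p_drop."""
    rng = rng if rng is not None else random.Random()
    return [[[0.0] * len(kp) if rng.random() < p_drop else list(kp) for kp in frame]
            for frame in frames]


# Stretch a 3-frame single-channel signal to 6 frames.
print(time_warp([[0.0], [1.0], [2.0]], 2.0))
```

Factors above 1 stretch a sequence (simulating slower movement) and factors below 1 compress it; in practice the factor would be sampled per training example so that the encoders see a range of movement speeds.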
4.5. Performance Comparison
Table 1 presents a comprehensive comparison of Causal-MMFNet against the selected baseline methods on the StrokeBalance-Sim test set. The results show that Causal-MMFNet outperforms all baselines on every evaluation metric for both the regression (MAE, RMSE, R²) and classification (AUROC, AUPRC, ECE) tasks.
Specifically, Causal-MMFNet achieved the lowest MAE (2.33±0.05) and RMSE (3.35±0.06), along with the highest R² (0.70±0.02), indicating accurate prediction of ΔBBS. For the classification task, Causal-MMFNet obtained the highest AUROC (0.840±0.011) and AUPRC (0.720±0.014), showing strong discriminatory power in identifying responders. Its ECE of 0.030±0.003 was also the lowest, indicating well-calibrated predicted response probabilities.
Compared with the previous state-of-the-art multimodal time-series learning framework, MM-TRNet, our method improved every reported metric: Causal-MMFNet reduced MAE by 2.1% (from 2.38 to 2.33) and RMSE by 1.8% (from 3.41 to 3.35), while improving AUROC by 0.7% (from 0.834 to 0.840) and AUPRC by 1.0% (from 0.713 to 0.720). We attribute these gains mainly to Causal-MMFNet’s novel components, particularly the dynamic cross-modal attention fusion and the individualized treatment effect estimation module, which help it exploit complex multimodal time-series data; the ablation study in Section 4.6 examines each component’s contribution. Simple statistical models, and even advanced deep learning models without explicit causal modeling or dynamic fusion, captured the relationships in the data less effectively.
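The relative changes quoted above follow directly from the reported means; the short computation below reproduces them. It checks only the arithmetic, not whether the differences are statistically significant given the reported standard deviations.

```python
def relative_change_pct(old, new):
    """Percentage change from old to new (negative means a decrease)."""
    return (new - old) / old * 100.0


# MM-TRNet -> Causal-MMFNet means, as reported in the text.
reported = {
    "MAE": (2.38, 2.33),
    "RMSE": (3.41, 3.35),
    "AUROC": (0.834, 0.840),
    "AUPRC": (0.713, 0.720),
}
changes = {name: round(relative_change_pct(old, new), 1)
           for name, (old, new) in reported.items()}
for name, pct in changes.items():
    print(f"{name}: {pct:+.1f}%")
# prints:
# MAE: -2.1%
# RMSE: -1.8%
# AUROC: +0.7%
# AUPRC: +1.0%
```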
4.6. Ablation Study
To validate the individual contributions of the key architectural innovations within Causal-MMFNet, we conducted an ablation study. Specifically, we investigated the impact of the dynamic cross-modal attention fusion mechanism, the Individual Treatment Effect (ITE) estimation module, and the causal consistency regularization term. The results, summarized in Table 2, highlight the importance of each component for the overall performance.
Removing the dynamic cross-modal attention fusion mechanism (replaced with simple concatenation of modality embeddings) led to a noticeable drop in performance across all metrics. For instance, MAE increased to 2.45, and AUROC dropped to 0.821. This demonstrates that adaptively weighting and fusing information from different modalities is crucial for building a rich and context-aware global patient representation, allowing the model to focus on the most relevant data for each patient’s unique profile.
Disabling the ITE estimation module (by replacing the dual prediction heads with a single head that predicts conditional outcomes based on the assigned treatment, without explicit counterfactual modeling) also resulted in degraded performance. While the impact on general prediction metrics (MAE, AUROC) was slightly less pronounced than removing dynamic attention, it specifically hampers the model’s ability to precisely estimate the differential benefits of TR versus CR, which is critical for personalized recommendations. The slight drops across all metrics suggest that explicitly modeling counterfactuals helps in learning more robust and generalizable representations for outcome prediction.
Finally, omitting the causal consistency regularization term led to a minor but consistent decrease in performance (e.g., MAE increased to 2.37 and ECE to 0.032). This suggests that encouraging consistent functional behavior between the treatment heads for similar patient features, even when only one outcome is observed, helps the model learn more robust causal relationships and mitigates biases stemming from differences between the treatment group distributions. The lower ECE obtained with this regularization further suggests that it contributes to reliable probability estimates.
Collectively, these ablation results confirm that each proposed component of Causal-MMFNet plays a vital role in its superior performance, enabling more accurate, robust, and causally-aware predictions and recommendations.
4.7. Treatment Allocation Quality
Beyond predictive accuracy, the ultimate goal of Causal-MMFNet is to provide effective personalized treatment recommendations. To evaluate the practical utility of our framework, we assessed the quality of its treatment allocation by analyzing the actual ΔBBS of patients recommended by the model. Specifically, we examined the average ΔBBS achieved by the Top-25% of patients with the largest predicted Individual Treatment Effect (ITE) magnitude, i.e., those predicted to benefit most from the model-recommended treatment (TR or CR) relative to the alternative. This metric indicates how well the model identifies patients who will respond strongly to a specific personalized regimen.
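One way to compute this allocation metric is sketched below. It assumes, as is possible in a simulated dataset, that the realized ΔBBS under both treatments is available for evaluation. It ranks patients by the magnitude of the predicted ITE (TR minus CR), keeps the Top-Q% (Q = 25% by default), and averages the realized ΔBBS under the recommended treatment separately for TR and CR recommendations. The helper name, ranking rule, and tie handling are hypothetical.

```python
import math


def top_q_allocation(ite_hat, dbbs_tr, dbbs_cr, q=0.25):
    """Mean realized ΔBBS per recommended arm among the Top-Q% patients by |predicted ITE|.

    ite_hat: predicted ITE per patient (TR minus CR)
    dbbs_tr, dbbs_cr: realized ΔBBS under TR and CR (assumed available, e.g., in simulation)
    """
    order = sorted(range(len(ite_hat)), key=lambda i: abs(ite_hat[i]), reverse=True)
    k = max(1, math.ceil(q * len(order)))
    groups = {"TR": [], "CR": []}
    for i in order[:k]:
        if ite_hat[i] > 0:
            groups["TR"].append(dbbs_tr[i])  # recommended TR: use the TR outcome
        else:
            groups["CR"].append(dbbs_cr[i])  # recommended CR: use the CR outcome
    return {arm: sum(v) / len(v) for arm, v in groups.items() if v}


print(top_q_allocation([3.0, -2.0, 0.5, -0.1],
                       [8.0, 4.0, 5.0, 5.0],
                       [5.0, 7.0, 5.0, 6.0], q=0.5))
```

Replacing the recommended outcome with the better of the two realized outcomes for each selected patient would give the corresponding "Oracle" value for the same subset.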
As shown in Figure 3, Causal-MMFNet’s personalized recommendations yielded a markedly higher average actual ΔBBS for the Top-25% recommended patients than the reference allocation strategies. Patients recommended TR by Causal-MMFNet (and in the Top-25% by predicted ITE) achieved an average ΔBBS of 7.2±0.03, and those recommended CR achieved an average ΔBBS of 6.8±0.03. These values are substantially higher than the average under random allocation (4.5±0.2, the overall average improvement across all patients and treatments without personalized recommendations) and also exceed the average improvements observed in the general TR and CR groups. The "Oracle" represents the theoretical maximum average improvement achievable if the best treatment for each of the Top-25% patients were known in advance. Although Causal-MMFNet does not reach oracle performance, it substantially closes the gap, supporting its practical value in guiding clinical decisions for personalized rehabilitation. These results support the effectiveness of the ITE estimation module and the overall framework in identifying individuals who will respond best to specific treatment pathways, thereby optimizing rehabilitation outcomes.
4.8. Modality Contribution Analysis
To further dissect the impact of each data source, we performed a modality contribution analysis by systematically training variants of Causal-MMFNet in which one or more input modalities were excluded. This experiment highlights the unique and synergistic value that each data stream brings to the overall predictive model. The baseline for comparison is the full Causal-MMFNet model. The results in Figure 4 illustrate the performance degradation when specific modalities are withheld.
The results show that all modalities contribute positively to the model’s performance, as removing any single modality degrades every metric. The largest drop occurs when Clinical Data is removed (MAE increases from 2.33 to 2.60 and AUROC drops from 0.840 to 0.798), suggesting that static patient characteristics and initial assessments provide a strong foundation for prediction. Removing Video Keypoints also has a substantial impact (MAE rises to 2.55), indicating the importance of the detailed kinematic information captured by posture and gait. IMU and Training Logs show smaller individual impacts, yet the full model that includes them performs best. This is reinforced by the "Clinical Only" and "Time-Series Only" variants, which show that relying on a single category of data results in substantially worse performance than the integrated multimodal approach. This analysis indicates strong multimodal synergy in Causal-MMFNet’s architecture, particularly through its dynamic cross-modal attention fusion, which integrates and leverages information from diverse data sources.
4.9. Analysis of Dynamic Attention Weights
The dynamic cross-modal attention fusion mechanism is designed to adaptively weigh the contribution of each modality based on the patient’s context. To understand how this mechanism operates, we analyzed the average attention weights assigned to each modality across the test set. We further investigated how these weights differ between patients predicted to be responders and those predicted to be non-responders, and between patients for whom TeleRehabilitation (TR) or Conventional Rehabilitation (CR) was recommended. This analysis, presented in Table 3, provides insight into which modalities the model prioritizes under different conditions.
The "Overall Average" row indicates that Clinical Data generally receives the highest attention weight (0.30±0.04), followed closely by Video Keypoints (0.26±0.05). This suggests that baseline patient characteristics and visual movement patterns are consistently deemed highly informative for balance recovery prediction.
Interestingly, for "Predicted Responders," the attention shifts slightly: IMU and Video Keypoints receive higher average weights (0.25 and 0.27, respectively), while the Clinical Data weight decreases. This implies that for patients with greater predicted potential for recovery, the model may place more emphasis on the dynamic movement patterns and rehabilitation effort captured by the time-series data. Conversely, for "Predicted Non-Responders," the Clinical Data weight increases to 0.34±0.04, suggesting that static baseline health indicators become more dominant in predicting limited improvement.
Similarly, for patients recommended TR ("TR Patients"), the Video Keypoints modality receives the highest attention (0.28±0.05), potentially reflecting the importance of visual feedback and observable progress in remote rehabilitation settings. For "CR Patients," Clinical Data again dominates (0.33±0.05), possibly indicating that patients who benefit more from conventional, in-person therapy have more complex or established clinical profiles that influence treatment choice. These patterns illustrate the adaptive nature of Causal-MMFNet’s fusion mechanism, which combines multimodal information according to the specific predictive context and patient characteristics.
4.10. Uncertainty Estimation Reliability
The integration of Monte Carlo Dropout (MC Dropout) for uncertainty estimation is crucial for providing clinicians with a measure of confidence alongside treatment recommendations. To evaluate the reliability of these uncertainty estimates, we assessed how well the predicted uncertainty (quantified as the standard deviation of MC Dropout predictions) correlates with the actual magnitude of prediction errors. A well-calibrated uncertainty mechanism should ideally assign higher uncertainty to predictions that are further from the true outcome.
Table 4 presents the average prediction error (MAE) within each quintile of predicted uncertainty: we divided the test set predictions into five equal-sized bins ordered from lowest to highest predicted uncertainty.
The results show a positive association between predicted uncertainty and actual prediction error. Predictions in the lowest-uncertainty quintile have an average MAE of 1.88±0.04, indicating high accuracy where the model is most confident, whereas predictions in the highest-uncertainty quintile have an average MAE of 2.89±0.07. Thus, when Causal-MMFNet expresses higher uncertainty, its predictions do tend to have larger errors. This supports the reliability of MC Dropout as an uncertainty quantification method within our framework, giving clinicians a meaningful signal for gauging the trustworthiness of personalized recommendations and for identifying cases that may warrant further clinical scrutiny or data collection.
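The binning behind Table 4 can be sketched as follows: sort test predictions by their MC Dropout standard deviation, split them into five equal-sized groups, and report the MAE in each. This is a minimal illustration with a hypothetical helper name; when the number of predictions is not divisible by five, the integer bin boundaries make some bins one element larger than others.

```python
def mae_by_uncertainty_bin(abs_errors, uncertainties, n_bins=5):
    """MAE within equal-count bins of predictions sorted by predicted uncertainty (low to high)."""
    n = len(abs_errors)
    if n < n_bins:
        raise ValueError("need at least one prediction per bin")
    order = sorted(range(n), key=lambda i: uncertainties[i])
    result = []
    for b in range(n_bins):
        idx = order[b * n // n_bins:(b + 1) * n // n_bins]
        result.append(sum(abs_errors[i] for i in idx) / len(idx))
    return result


# Toy case where error grows with uncertainty: bin MAEs increase monotonically.
print(mae_by_uncertainty_bin(list(range(1, 11)), list(range(10))))
```

A monotonically increasing sequence of bin MAEs, as in Table 4, is the pattern one hopes to see; a flat sequence would indicate that the uncertainty estimates carry little information about error.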