1 Background

In the last decade, with the rapid growth of electronic health records (EHRs), data-driven studies such as drug repurposing and medical event prediction have become increasingly popular in the medical domain. An EHR is composed of a patient's visits in chronological order. Each visit contains various medical information, such as demographics, vital signs, diagnoses, medications, procedures and laboratory test results. A medical event prediction task predicts certain types of medical events, such as diseases, prescriptions or outcomes, from other types of medical information or from the patient's medical history. For example, next-period prescription prediction is to predict all medications a patient will receive in the next period from his/her medical history. The critical challenge of EHR-driven medical event prediction is how to represent a patient's longitudinal medical data accurately, a problem known as patient representation.

A large number of methods have been proposed to predict medical events. Because each patient's medical data is time series data, typical time series analysis methods have been applied to medical event prediction. These include machine learning methods based on manually-crafted features [1] and Autoregressive Integrated Moving Average (ARIMA) models [2]. The manually-crafted features used in traditional machine learning methods are usually task-specific and do not capture the complex relations among different types of medical information and time. In recent years, a number of deep learning methods have been applied to medical event prediction and have achieved good results.

Basic deep learning methods such as Recurrent Neural Networks (RNN) [3] and Convolutional Neural Networks (CNN) [4] have shown great potential for patient representation. However, they are still limited in fusing heterogeneous information and in exploiting temporal information.

To deal with these limitations, researchers have attempted to extend the basic deep learning methods.

For example, Jin et al. proposed a new Long Short-Term Memory (LSTM, a variant of RNN) learning framework to model the inter-correlations of different types of medical information [5]. They developed three LSTM variants based on demographics, laboratory test results and historical prescriptions to predict next-period prescriptions. Among these variants, the decomposed LSTM achieved the best results. It uses an LSTM to model the historical prescription sequence and integrates laboratory test results into each cell.

An et al. proposed an attention-based LSTM for high-risk prediction [6]. The model first uses an attention-based LSTM to represent each type of medical information individually: the diagnosis sequence, the laboratory sequence and their combined sequence. It then concatenates the three representations for prediction.

All these methods represent patient status sequences. However, they do not consider the complex interactions among different types of medical information over time, that is, temporal medical event graphs. Graph neural networks (GNNs) are a promising way to represent such graphs.

In this study, we focus on next-period prescription prediction and propose a novel hybrid method of RNN and GNN (called RGNN) for this task. In RGNN, RNN and GNN first represent patient longitudinal medical data individually from two views, and the two representations are then combined. The decomposed LSTM, a state-of-the-art method for next-period prescription prediction, represents patient status sequences, and GNN represents temporal medical event graphs.

For the GNN, we present two methods to build and represent medical event graphs. Experiments on the public MIMIC-III ICU data show that the proposed method is effective for next-period prescription prediction and that RNN and GNN are mutually complementary.

In summary, our main contributions are:

  • We propose a novel hybrid method of RNN and GNN, called RGNN, to represent patient longitudinal medical data from two views and apply it to next-period prescription prediction. In RGNN, RNN is used for patient status sequence representation, and GNN for temporal medical event graph representation.

  • We present two methods to build temporal medical event graphs and represent them by GNN. We first build the graphs from co-occurrences of different types of medical events at two adjacent time points, where the nodes are medical events and the edges are co-occurrence relations. We then add the time intervals between medical events as edge weights.

  • We evaluate the performance of RGNN on the public MIMIC-III ICU data. Experimental results demonstrate that RGNN achieves promising performance on next-period prescription prediction and that RNN and GNN are mutually complementary.

2 Related work

2.1 Medical event prediction

In the medical domain, medical event prediction is an active research topic. Its main task is to predict future medical events, including risk of diseases [6], prescriptions [5], mortality [7], hospital readmission [8], length of hospital stay [9], postoperative complications [10], survival time [11] and so on. In this paper, we mainly focus on next-period prescription prediction.

As patient medical data is typical time series data, almost all time series analysis methods can be used for medical event prediction. They fall into three categories:

(1) statistical analysis methods, such as the Cox proportional hazards model [12] and the Hierarchical Association Rule Model (HARM) [13];

(2) statistical machine learning methods based on manually-crafted features, such as artificial neural networks (ANN), decision trees, logistic regression and Support Vector Machines (SVM) [14, 15];

(3) deep learning methods, such as auto-encoder models [16], CNN [11], RNN [17], LSTM and Bi-LSTM [18].

In early studies, McCormick et al. proposed HARM to predict disease risk from medical data using association analysis and Bayesian estimation [13]. Deken et al. investigated ANN, decision trees and logistic regression for breast cancer survival time prediction [19]. Choi et al. presented an RNN-based system, Doctor AI, to predict multiple medical events at the next visit [20]. Baytas et al. proposed a time-aware LSTM network, called T-LSTM, for patient subtyping [21].

Among the three categories, deep learning methods have attracted increasing attention in recent years because they can model complex non-linear relationships within data and have achieved remarkable results [22]. For next-period prescription prediction, one of the latest studies is Jin et al.'s work [5]. It proposed an LSTM variant, called decomposed LSTM, that fuses different types of medical information effectively.

2.2 Graph neural network

Graph neural networks (GNNs) are a class of deep neural networks designed for learning on complex graphs [23]. Several methods have recently been proposed to compute representations of nodes, edges and graphs [24,25,26]. Among them, the most common is the Graph Convolutional Network (GCN) [27], which computes the representation of a node recursively from its neighbors. GCN is widely applied in domains such as natural language processing and knowledge graph representation.

In the medical domain, GCN has begun to be applied to many tasks. For example, Choi et al. used GCN to learn medical concept representations from a graph of medical ontology knowledge [28]. Ma et al. formulated drug–drug interaction (DDI) prediction as a graph classification problem and solved it with GCN [29].

Besides GCN, other GNNs have also been proposed recently, such as GAMENet [30] and Decagon [31]. GAMENet is a Graph Augmented Memory Network that integrates a DDI knowledge graph for personalized medication combination recommendation. Decagon is a multi-modal GNN for drug side effect prediction.

3 Method

3.1 Task introduction

Next-period prescription prediction can be formulated as a multi-label classification problem. We are given a set of patients \(P = \{p_{1}, p_{2}, \ldots, p_{N}\}\) with their medical histories \(X = \{x^{p_{1}}, x^{p_{2}}, \ldots, x^{p_{N}}\}\) and demographics \(D = \{d^{p_{1}}, d^{p_{2}}, \ldots, d^{p_{N}}\}\). The task is to predict their prescriptions \(M = \{m^{p_{1}}, m^{p_{2}}, \ldots, m^{p_{N}}\}\) for the next period.

For patient \(p_{k}\), \(x^{p_{k}}\) is a visit sequence made up of sequences of different types of medical events, such as the diagnosis sequence, the laboratory test sequence and the prescription sequence. Following Jin et al.'s work, we only consider two of these: the prescription sequence \(x_{m}^{p_{k}} = \{x_{m_{1}}^{p_{k}}, x_{m_{2}}^{p_{k}}, \ldots, x_{m_{t}}^{p_{k}}\}\) and the laboratory test sequence \(x_{l}^{p_{k}} = \{x_{l_{1}}^{p_{k}}, x_{l_{2}}^{p_{k}}, \ldots, x_{l_{t}}^{p_{k}}\}\).

Here \(x_{m_{i}}^{p_{k}}\) and \(x_{l_{i}}^{p_{k}}\) are the medications and laboratory tests at time i. They are represented by a multi-hot (binary) vector over the medication vocabulary (denoted as \(\Sigma\)) and a multi-hot vector over the laboratory test item vocabulary (denoted as L), respectively. We use \(x_{i}^{p_{k}} = [x_{m_{i}}^{p_{k}}, x_{l_{i}}^{p_{k}}]\) to denote the status of \(p_{k}\) at time i. The prediction target is \(m^{p_{k}} = x_{m_{t+1}}^{p_{k}}\).

Figure 1 gives an example of a patient's medical history, where the numbers in parentheses are the indices of items in the vocabularies.
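The following Python sketch illustrates this input encoding. It rests on a few assumptions:

- the vocabularies and the visit history are hypothetical toy values;
- each time point is given as a (medications, laboratory tests) pair;
- each time point is encoded as a binary (multi-hot) vector, because it can contain several medications and tests.

The helper names `multi_hot` and `build_example` are illustrative and do not come from the original work.

```python
from typing import Dict, List, Sequence, Tuple

Visit = Tuple[Sequence[str], Sequence[str]]  # (medications, lab tests) at one time point


def multi_hot(items: Sequence[str], vocab: Dict[str, int]) -> List[int]:
    """Binary vector with a 1 at the vocabulary index of every item present."""
    vec = [0] * len(vocab)
    for item in items:
        if item in vocab:  # items outside the (filtered) vocabulary are dropped
            vec[vocab[item]] = 1
    return vec


def build_example(visits: List[Visit], med_vocab: Dict[str, int],
                  lab_vocab: Dict[str, int]):
    """Return the status sequence [x_1, ..., x_t] and the target x_{m_{t+1}}.

    The last visit only supplies the prediction target (its medications).
    """
    statuses = [multi_hot(meds, med_vocab) + multi_hot(labs, lab_vocab)  # [x_m_i, x_l_i]
                for meds, labs in visits[:-1]]
    target = multi_hot(visits[-1][0], med_vocab)
    return statuses, target


# Hypothetical toy vocabularies and history (not taken from MIMIC-III).
med_vocab = {"heparin": 0, "insulin": 1, "furosemide": 2}
lab_vocab = {"glucose": 0, "potassium": 1}
visits = [
    (["heparin"], ["glucose"]),
    (["heparin", "insulin"], ["glucose", "potassium"]),
    (["furosemide"], []),
]
statuses, target = build_example(visits, med_vocab, lab_vocab)
print(statuses)  # → [[1, 0, 0, 1, 0], [1, 1, 0, 1, 1]]
print(target)    # → [0, 0, 1]
```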

Fig. 1 Example of the medical history of a patient (\(p_{k}\))

3.2 Model description

Fig. 2 shows the structure of our proposed hybrid method of decomposed LSTM and GNN (RGNN). The medical history \(x^{p_{k}}\) of patient \(p_{k}\) is represented separately by the decomposed LSTM and the GNN. The representations from the two views are then combined in one of several ways. The final representation of \(x^{p_{k}}\) (denoted by \(h^{p_{k}}\)) is concatenated with the demographics \(d^{p_{k}}\) of patient \(p_{k}\). The result is transformed by a layer with a Rectified Linear Unit (ReLU) activation function and fed into a sigmoid layer that predicts the medications for the next period.

Fig. 2 Structure of the hybrid method of decomposed LSTM and GNN (RGNN)

Following Jin et al.'s work [5], we use the following decomposed LSTM (shown in Fig. 3) to represent the status sequence \([x_{1}^{p_{k}}, x_{2}^{p_{k}}, \ldots, x_{t}^{p_{k}}]\) of patient \(p_{k}\):

$$s_{t} = \sigma \left( {W_{s} \left[ {x_{{m_{t} }}^{{p_{k} }} ,h_{t - 1} } \right] + b_{s} } \right),$$
(1)
$$d_{t} = \sigma \left( {W_{decomp} C_{t - 1} + b_{decomp} } \right),$$
(2)
$$\widetilde{C}_{t}^{l} = d_{t} \odot { \tanh }\left( {W_{l} x_{{l_{t} }}^{{p_{k} }} + b_{l} } \right),$$
(3)
$$\widetilde{C}_{t} = { \tanh }\left( {W_{c} \left[ {x_{{m_{t} }}^{{p_{k} }} ,h_{t - 1} } \right] + b_{c} } \right),$$
(4)
$$C_{t} = f_{t} \odot \left( {C_{t - 1} + \widetilde{C}_{t}^{l} } \right) + i_{t} \odot \widetilde{C}_{t} ,$$
(5)
$$h_{t} = o_{t} \odot { \tanh }\left( {C_{t} } \right),$$
(6)

where \(\sigma(\cdot)\), \(\tanh(\cdot)\) and \(\odot\) denote the logistic function, the hyperbolic tangent function and element-wise multiplication, respectively. \(s\) stands for each of the three gates (forget gate f, input gate i and output gate o), so Eq. (1) is applied once per gate. \(d_{t}\) is the decomposition gate, which scales the laboratory test contribution in Eq. (3). Ws are weight matrices and bs are bias vectors.
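To make the data flow of Eqs. (1)-(6) concrete, the following pure-Python sketch runs the decomposed LSTM one step at a time. It is not the authors' TensorFlow implementation. The class name `DecomposedLSTMCell`, the random weight initialization and the zero biases are assumptions, chosen only so that the example runs. The laboratory vector enters only through the decomposition gate and Eq. (3), while the medication vector drives the standard gates.

```python
import math
import random
from typing import List

Vec = List[float]
Mat = List[List[float]]


def matvec(W: Mat, x: Vec) -> Vec:
    """Matrix-vector product W x."""
    return [sum(w * v for w, v in zip(row, x)) for row in W]


def add(a: Vec, b: Vec) -> Vec:
    return [x + y for x, y in zip(a, b)]


def mul(a: Vec, b: Vec) -> Vec:
    """Element-wise product, the ⊙ operator in Eqs. (3), (5) and (6)."""
    return [x * y for x, y in zip(a, b)]


def sigmoid(v: Vec) -> Vec:
    return [1.0 / (1.0 + math.exp(-x)) for x in v]


def tanh(v: Vec) -> Vec:
    return [math.tanh(x) for x in v]


class DecomposedLSTMCell:
    """One step of Eqs. (1)-(6); weights are random placeholders, not trained values."""

    def __init__(self, n_med: int, n_lab: int, hidden: int, seed: int = 0):
        rng = random.Random(seed)

        def rand(rows: int, cols: int) -> Mat:
            return [[rng.uniform(-0.1, 0.1) for _ in range(cols)] for _ in range(rows)]

        n_in = n_med + hidden  # size of the concatenation [x_m_t, h_{t-1}]
        self.W = {g: rand(hidden, n_in) for g in ("f", "i", "o", "c")}
        self.W_decomp = rand(hidden, hidden)
        self.W_l = rand(hidden, n_lab)
        self.b = {g: [0.0] * hidden for g in ("f", "i", "o", "c", "decomp", "l")}

    def step(self, x_med: Vec, x_lab: Vec, h_prev: Vec, c_prev: Vec):
        z = list(x_med) + list(h_prev)  # [x_m_t, h_{t-1}]
        f, i, o = (sigmoid(add(matvec(self.W[g], z), self.b[g]))
                   for g in ("f", "i", "o"))                                # Eq. (1)
        d = sigmoid(add(matvec(self.W_decomp, c_prev), self.b["decomp"]))  # Eq. (2)
        c_lab = mul(d, tanh(add(matvec(self.W_l, x_lab), self.b["l"])))    # Eq. (3)
        c_med = tanh(add(matvec(self.W["c"], z), self.b["c"]))             # Eq. (4)
        c = add(mul(f, add(c_prev, c_lab)), mul(i, c_med))                 # Eq. (5)
        h = mul(o, tanh(c))                                                # Eq. (6)
        return h, c


# Usage on a hypothetical two-step history (3 medications, 2 lab items, 4 hidden units).
cell = DecomposedLSTMCell(n_med=3, n_lab=2, hidden=4)
h, c = [0.0] * 4, [0.0] * 4
for x_med, x_lab in [([1, 0, 0], [1, 0]), ([1, 1, 0], [1, 1])]:
    h, c = cell.step(x_med, x_lab, h, c)
print(h)  # final hidden state h_t, passed on to the combination step
```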

Fig. 3 Structure of decomposed LSTM

We propose two methods to build the temporal medical event graph: (1) the co-occurrence graph and (2) the time-aware graph. In the co-occurrence graph, medical events at time i (\(x_{i}^{p_{k}}\)) are linked to medical events at time i + 1 (\(x_{i+1}^{p_{k}}\)). The graph is denoted by \(A \in \Sigma \times L \times T\) (\(T = \{1, 2, \ldots, t\}\)), where \(A_{i} \in \Sigma \times L\) is the adjacency matrix between \(x_{i}^{p_{k}}\) and \(x_{i+1}^{p_{k}}\). After obtaining the co-occurrence graph A, we use GCN, a kind of GNN, to represent A as follows:

$$\widehat{A} = \widetilde{D}^{{ - \frac{1}{2}}} \left( {A + I} \right)\widetilde{D}^{{ - \frac{1}{2}}} ,$$
(7)
$$Z = MEAN\left\{ {\widehat{A} ReLU\left( {\widehat{A}XW_{1} } \right)W_{2} } \right\},$$
(8)

where \(\widetilde{D}\) is the degree matrix of \(A + I\), I is the identity matrix, X is the input node feature matrix, \(W_{1}\) and \(W_{2}\) are weight matrices, and MEAN{·} is the mean function.
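The small pure-Python sketch below illustrates Eqs. (7) and (8). It relies on three assumptions that the text does not state:

- the adjacency matrix passed in is square and symmetric over all event nodes;
- X holds one feature row per node;
- MEAN{·} averages over nodes to produce a single graph vector.

Following the usual GCN normalization, the degree is computed from A + I, so every node has degree at least 1. The function names are hypothetical.

```python
import math
from typing import List

Mat = List[List[float]]


def matmul(A: Mat, B: Mat) -> Mat:
    """Dense matrix product A B."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def normalize_adjacency(A: Mat) -> Mat:
    """Eq. (7): D^{-1/2} (A + I) D^{-1/2}, with D the degree matrix of A + I."""
    n = len(A)
    A_self = [[A[r][c] + (1.0 if r == c else 0.0) for c in range(n)] for r in range(n)]
    deg = [sum(row) for row in A_self]
    return [[A_self[r][c] / math.sqrt(deg[r] * deg[c]) for c in range(n)] for r in range(n)]


def relu(M: Mat) -> Mat:
    return [[max(0.0, v) for v in row] for row in M]


def gcn_readout(A: Mat, X: Mat, W1: Mat, W2: Mat) -> List[float]:
    """Eq. (8): two GCN layers followed by a mean over nodes."""
    A_hat = normalize_adjacency(A)
    H1 = relu(matmul(matmul(A_hat, X), W1))  # first layer: ReLU(Â X W1)
    H2 = matmul(matmul(A_hat, H1), W2)       # second layer: Â H1 W2
    return [sum(col) / len(H2) for col in zip(*H2)]  # MEAN over nodes


# Two connected nodes with one-hot features and toy weights (all hypothetical).
A = [[0.0, 1.0], [1.0, 0.0]]
X = [[1.0, 0.0], [0.0, 1.0]]
W1 = [[1.0], [1.0]]
W2 = [[1.0]]
print(gcn_readout(A, X, W1, W2))  # → [1.0]
```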

Considering the different time intervals between adjacent time points, we build a time-aware graph \(\overline{A}\). It replaces each entry \(a_{jk} = 1\) in \(A_{i}\) with \(\overline{a}_{jk} = \frac{1}{\Delta t_{i}}\) in \(\overline{A}_{i}\), where \(\Delta t_{i}\) is the time interval between time i and time i + 1. Events separated by longer intervals are therefore connected by weaker edges.
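The sketch below shows one possible construction of the slices \(A_i\) and their time-aware counterparts. The text fixes only the shape \(A_i \in \Sigma \times L\). This sketch therefore assumes one reading of it: medication j at time i is linked to laboratory test k at time i + 1. The timestamps, their unit (days) and the helper names are hypothetical. `to_square` embeds a slice into a symmetric adjacency matrix over all |Σ| + |L| nodes, which is the square form that A + I in Eq. (7) requires.

```python
from typing import List, Sequence, Tuple

Mat = List[List[float]]
Visit = Tuple[Sequence[int], Sequence[int]]  # (medication indices, lab test indices)


def slice_adjacency(meds_i: Sequence[int], labs_next: Sequence[int],
                    n_med: int, n_lab: int, weight: float) -> Mat:
    """|Σ| x |L| slice: medication j at time i linked to lab test k at time i + 1."""
    A_i = [[0.0] * n_lab for _ in range(n_med)]
    for j in meds_i:
        for k in labs_next:
            A_i[j][k] = weight
    return A_i


def build_graph(visits: List[Visit], times: List[float], n_med: int, n_lab: int,
                time_aware: bool = False) -> List[Mat]:
    """Return slices A_1..A_{t-1}; the time-aware variant uses 1 / Δt_i instead of 1."""
    slices = []
    for i in range(len(visits) - 1):
        dt = times[i + 1] - times[i]
        if dt <= 0:
            raise ValueError("time points must be strictly increasing")
        weight = 1.0 / dt if time_aware else 1.0
        slices.append(slice_adjacency(visits[i][0], visits[i + 1][1], n_med, n_lab, weight))
    return slices


def to_square(A_i: Mat) -> Mat:
    """Embed a bipartite slice into a symmetric (|Σ| + |L|) x (|Σ| + |L|) matrix."""
    n_med, n_lab = len(A_i), len(A_i[0])
    S = [[0.0] * (n_med + n_lab) for _ in range(n_med + n_lab)]
    for j in range(n_med):
        for k in range(n_lab):
            S[j][n_med + k] = S[n_med + k][j] = A_i[j][k]
    return S


# Hypothetical history: 2 medications, 2 lab items, time points in days.
visits = [([0], [1]), ([1], [0]), ([0], [1])]
times = [0.0, 2.0, 3.0]
print(build_graph(visits, times, 2, 2))                   # co-occurrence graph
print(build_graph(visits, times, 2, 2, time_aware=True))  # time-aware graph
```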

We combine the representations from decomposed LSTM and GNN in the following three ways:

  1. Concatenation mechanism: \([h_{t}, g]\)

  2. Attention mechanism:

    $$\left[ \sum_{1 \le i \le t} \alpha_{i} h_{i}, h_{t} \right] \quad \text{s.t.} \quad \alpha_{i} = \frac{\exp\left( f(g, h_{i}) \right)}{\sum_{1 \le k \le t} \exp\left( f(g, h_{k}) \right)}$$
    (9)

  3. Gate mechanism: \(\sigma(g) \odot h_{t}\)

where \(h_{i}\) is the hidden state of the decomposed LSTM at time i, g is the representation from the GNN, and \(f(\cdot, \cdot)\) is the attention scoring function.
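The minimal sketch below compares the three combination mechanisms side by side. The text does not define the scoring function f in Eq. (9). A dot product is assumed here, which requires g and \(h_i\) to have the same dimension; both are 128 in the experimental settings. The gate is applied element-wise. The function names are illustrative only.

```python
import math
from typing import List

Vec = List[float]


def combine_concat(h_t: Vec, g: Vec) -> Vec:
    """Concatenation mechanism: [h_t, g]."""
    return h_t + g


def combine_attention(hs: List[Vec], g: Vec) -> Vec:
    """Attention mechanism, Eq. (9), with f(g, h_i) assumed to be a dot product."""
    scores = [sum(a * b for a, b in zip(g, h)) for h in hs]
    m = max(scores)  # subtract the maximum for a numerically stable softmax
    exps = [math.exp(s - m) for s in scores]
    alphas = [e / sum(exps) for e in exps]
    context = [sum(a * h[d] for a, h in zip(alphas, hs)) for d in range(len(hs[0]))]
    return context + hs[-1]  # [Σ_i α_i h_i, h_t]


def combine_gate(h_t: Vec, g: Vec) -> Vec:
    """Gate mechanism: σ(g) ⊙ h_t, element-wise."""
    return [h / (1.0 + math.exp(-x)) for h, x in zip(h_t, g)]


# Hypothetical LSTM states h_1, h_2 and GNN vector g.
hs = [[1.0, 0.0], [0.0, 1.0]]
g = [0.0, 0.0]
print(combine_concat(hs[-1], g))    # → [0.0, 1.0, 0.0, 0.0]
print(combine_attention(hs, g))     # → [0.5, 0.5, 0.0, 1.0]
print(combine_gate([2.0, 4.0], g))  # → [1.0, 2.0]
```

With concatenation, the downstream layer must learn how to weigh the two views. Attention and gating instead let g modulate the LSTM states directly.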

The final medical history representation \(h^{p_{k}}\) of patient \(p_{k}\) is concatenated with the demographics \(d^{p_{k}}\) and transformed into \(r^{p_{k}}\) by a ReLU layer, defined as follows:

$$r^{p_{k}} = ReLU\left( W_{r} \left[ h^{p_{k}}, d^{p_{k}} \right] + b_{r} \right),$$
(10)

where \(W_{r}\) is the weight matrix and \(b_{r}\) is the bias vector.

Finally, a sigmoid layer predicts the medications for the next period:

$$\widetilde{x}_{m}^{{p_{k} }} = \sigma \left( {W_{m} r^{{p_{k} }} + b_{m} } \right),$$
(11)

where \(W_{m}\) is the weight matrix and \(b_{m}\) is the bias vector.
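Eqs. (10) and (11) form a small prediction head on top of the combined representation. The sketch below uses hypothetical toy weights and assumes the demographics vector \(d^{p_k}\) is already numeric. Its output is one independent probability per medication in the vocabulary, which is what makes the task multi-label.

```python
import math
from typing import List

Vec = List[float]
Mat = List[List[float]]


def dense(W: Mat, x: Vec, b: Vec) -> Vec:
    """Affine map W x + b."""
    return [sum(w * v for w, v in zip(row, x)) + bias for row, bias in zip(W, b)]


def predict_next_medications(h: Vec, d: Vec, W_r: Mat, b_r: Vec,
                             W_m: Mat, b_m: Vec) -> Vec:
    r = [max(0.0, v) for v in dense(W_r, h + d, b_r)]                # Eq. (10): ReLU over [h, d]
    return [1.0 / (1.0 + math.exp(-v)) for v in dense(W_m, r, b_m)]  # Eq. (11): sigmoid


# Toy example: 1-dim history vector, 1-dim demographics, 2 hidden units, 2 medications.
probs = predict_next_medications(
    h=[1.0], d=[0.0],
    W_r=[[1.0, 0.0], [-1.0, 0.0]], b_r=[0.0, 0.0],
    W_m=[[0.0, 0.0], [2.0, 0.0]], b_m=[0.0, 0.0],
)
print(probs)  # one probability per medication in the vocabulary
```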

To learn the model parameters, we minimize the binary cross-entropy between the predicted and the true next-period medications:

$$LOSS = -\frac{1}{N} \sum_{1 \le k \le N} \left( m^{p_{k}} \log \widetilde{x}_{m}^{p_{k}} + \left( 1 - m^{p_{k}} \right) \log \left( 1 - \widetilde{x}_{m}^{p_{k}} \right) \right)$$
(12)
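For reference, the multi-label binary cross-entropy of Eq. (12) can be computed as below. The text does not specify whether the per-medication terms are summed or averaged; this sketch sums them and averages over the N patients. It also clips probabilities by a small epsilon so that log(0) never occurs. The clipping is an implementation choice, not something taken from the paper.

```python
import math
from typing import List


def multilabel_bce(preds: List[List[float]], targets: List[List[int]],
                   eps: float = 1e-7) -> float:
    """Binary cross-entropy summed over medications and averaged over N patients."""
    total = 0.0
    for p_vec, y_vec in zip(preds, targets):
        for p, y in zip(p_vec, y_vec):
            p = min(max(p, eps), 1.0 - eps)  # clip so log() never sees 0
            total += y * math.log(p) + (1 - y) * math.log(1.0 - p)
    return -total / len(preds)


# One hypothetical patient, two medications: uninformative vs. near-perfect prediction.
print(multilabel_bce([[0.5, 0.5]], [[1, 0]]))  # equals 2 * ln 2
print(multilabel_bce([[1.0, 0.0]], [[1, 0]]))  # close to 0
```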

4 Experiments

4.1 Dataset

In this paper, we select the records of patients with at least two visits from MIMIC-III. MIMIC-III is a publicly available dataset containing real-world medical data from 53,423 hospital admissions of adult patients (age ≥ 16 years) to the ICUs of Beth Israel Deaconess Medical Center (Boston, Massachusetts) between 2001 and 2012. It also contains clinical data of 7787 newborns from 2001 to 2008. To avoid the sparse data problem, we follow previous studies [5, 30]: we remove laboratory test items appearing fewer than 50 times and keep the 200 most frequent medications.

As shown in Table 1, the final dataset contains 7121 patients, with 11,269,796 laboratory test items of 462 types and 1,030,083 medications of 200 types. On average:

- a patient has 2.68 visits;
- each visit has 590 laboratory test items (lab tests) and 54 prescriptions;
- each prescription has 6.26 medications;
- the time interval between two visits is 2.26 years;
- the time interval between two prescriptions is 1 day.

We randomly split the selected dataset at the patient level into a training set, a development set and a test set with a ratio of 80:10:10. The statistics of the dataset are listed in Table 2, where #* is the number of *.
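The cohort selection and splitting steps described above can be sketched as follows. The data layout (a dict from patient ID to a list of (medications, laboratory tests) visits), the function names and the random seed are hypothetical. The thresholds follow the text: at least two visits, laboratory items appearing at least 50 times, and the top 200 medications. The split is 80:10:10 at the patient level. Splitting by patient rather than by visit keeps all visits of one patient in the same subset.

```python
import random
from collections import Counter
from typing import Dict, List, Sequence, Tuple

Visit = Tuple[Sequence[str], Sequence[str]]  # (medications, lab test items)


def keep_multi_visit(patients: Dict[str, List[Visit]]) -> Dict[str, List[Visit]]:
    """Keep patients with at least two visits."""
    return {pid: v for pid, v in patients.items() if len(v) >= 2}


def build_vocabularies(patients: Dict[str, List[Visit]], min_lab_count: int = 50,
                       top_k_meds: int = 200) -> Tuple[List[str], List[str]]:
    """Return (medication vocabulary, lab vocabulary) after frequency filtering."""
    lab_counts = Counter(lab for visits in patients.values() for _, labs in visits for lab in labs)
    med_counts = Counter(med for visits in patients.values() for meds, _ in visits for med in meds)
    lab_vocab = sorted(item for item, n in lab_counts.items() if n >= min_lab_count)
    med_vocab = [med for med, _ in med_counts.most_common(top_k_meds)]
    return med_vocab, lab_vocab


def split_patients(patient_ids: List[str], ratios=(0.8, 0.1, 0.1), seed: int = 0):
    """Random patient-level split into training, development and test sets."""
    ids = list(patient_ids)
    random.Random(seed).shuffle(ids)
    n_train = round(ratios[0] * len(ids))
    n_dev = round(ratios[1] * len(ids))
    return ids[:n_train], ids[n_train:n_train + n_dev], ids[n_train + n_dev:]


# Hypothetical toy records; small thresholds so the filtering is visible.
patients = {
    "p1": [(["heparin"], ["glucose"]), (["heparin", "insulin"], ["glucose", "sodium"])],
    "p2": [(["insulin"], ["glucose"])],  # single visit: removed
}
cohort = keep_multi_visit(patients)
print(build_vocabularies(cohort, min_lab_count=2, top_k_meds=1))  # → (['heparin'], ['glucose'])
```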

Table 1 Statistics of patient data
Table 2 Statistics of the dataset used in this study

4.2 Evaluation metrics

Because next-period medication prediction is a class-imbalanced problem, we evaluate model performance with two metrics: the area under the ROC (Receiver Operating Characteristic) curve (AUC) and the area under the precision-recall curve (AUPR). AUC is the primary metric.
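Both metrics can be computed without external libraries, as in the sketch below. The text does not state whether AUC and AUPR are micro- or macro-averaged over medications. This sketch assumes micro-averaging, flattening all patient-medication pairs into one list. It estimates AUPR by average precision, which is one common estimator of the area under the precision-recall curve. The function names are hypothetical.

```python
from typing import List, Sequence


def flatten(rows: Sequence[Sequence[float]]) -> List[float]:
    """Flatten a patients x medications matrix into one list (micro-averaging)."""
    return [v for row in rows for v in row]


def roc_auc(y_true: Sequence[int], y_score: Sequence[float]) -> float:
    """Probability that a random positive outranks a random negative (ties count 1/2)."""
    pos = [s for y, s in zip(y_true, y_score) if y == 1]
    neg = [s for y, s in zip(y_true, y_score) if y == 0]
    if not pos or not neg:
        raise ValueError("AUC needs at least one positive and one negative label")
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def average_precision(y_true: Sequence[int], y_score: Sequence[float]) -> float:
    """AUPR estimated as the mean precision at the rank of each positive."""
    ranked = sorted(zip(y_score, y_true), key=lambda pair: -pair[0])
    hits, precisions = 0, []
    for rank, (_, y) in enumerate(ranked, start=1):
        if y == 1:
            hits += 1
            precisions.append(hits / rank)
    return sum(precisions) / len(precisions) if precisions else 0.0


# Hypothetical labels and scores: 2 patients x 2 medications.
y = flatten([[1, 0], [1, 0]])
s = flatten([[0.9, 0.8], [0.7, 0.1]])
print(roc_auc(y, s))            # → 0.75
print(average_precision(y, s))  # ≈ 0.833
```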

4.3 Experimental settings

We first evaluate the baselines: the basic LSTM that considers different types of medical events separately, Doctor AI [20], T-LSTM [21], decomposed LSTM [5] and GNN. We then compare different RGNN variants, denoted by RGNN-G-C. Here G ∈ {“co-occurrence graph (CG)”, “time-aware graph (TG)”} is the graph construction method. C ∈ {“concatenation mechanism (CAT)”, “attention mechanism (ATT)”, “gate mechanism (GAT)”} is the way the representations from decomposed LSTM and GNN are combined. Like RGNN, GNN has two variants according to the graph construction method, denoted by GNN-CG and GNN-TG. All methods are implemented in TensorFlow, and all models are trained three times with the following hyperparameters:

  • Batch size: mini-batch

  • Optimizer: Adam

  • Learning rate: 1 × 10⁻³

  • Epoch: 1000

  • Hidden states in decomposed LSTM: 128

  • Hidden states in GNN (in Eq. 9): 128

  • Hidden states in ReLU in Fig. 2: 256

5 Results

Table 3 compares our method with other methods. Among the baselines, decomposed LSTM and GNN are designed to fuse different types of medical events. They outperform the other LSTM variants, which consider different types of medical events separately, by about 2% in AUC. RGNN, which combines decomposed LSTM and GNN, improves AUC by 0.9% to 2.9% over each of them. Among the RGNN variants, those using TG are slightly better than those using CG. Those using ATT or GAT perform better than those using CAT, with an AUC difference of around 0.8%. RGNN-TG-ATT achieves the highest AUC of 0.8387 and the highest AUPR of 0.2848.

Table 3 Comparison of our method with other methods

Moreover, we investigate the performance of RGNN-TG-ATT on the top 10 diagnoses (as shown in Table 4). The AUCs of RGNN-TG-ATT on all the top 10 diagnoses are higher than 0.81. The highest AUC is 0.8739 on chest pain, and the lowest AUC is 0.8130 on altered mental status. RGNN-TG-ATT achieves the highest AUPR of 0.3471 on coronary artery disease and the lowest AUPR of 0.2427 on upper gastrointestinal bleeding.

Table 4 Performance of RGNN-TG-ATT on the top 10 diagnoses

6 Discussion

In this study, we investigate two methods for representing patients' longitudinal and heterogeneous medical data: decomposed LSTM and GNN. Both methods are specifically designed to fuse different types of interdependent medical information. As described in the Method section, decomposed LSTM represents a patient status sequence, while GNN represents the temporal medical event graph. The former mainly focuses on the patient status at each time point, and the latter on the medical events as a whole. This may be the main reason that both methods outperform the basic LSTM that considers different types of medical information separately, as shown in Table 3.

RGNN using TG is slightly better than RGNN using CG, plausibly because patient medical data is time-sequential and TG encodes the time intervals. To assess the effect of time on GNN, we further compare RGNN-TG-ATT with RGNN-CG-ATT on the top 10 diagnoses (Table 5). In Table 5, time volatility is the standard deviation of the time intervals divided by their mean. In most cases, the larger the time volatility, the larger the effect of time on AUC. For example, the largest time volatility of 2.8871 corresponds to the biggest AUC difference of 0.0161, and the smallest time volatility of 2.3731 corresponds to the smallest AUC difference of −0.0003.

From Table 5, we find that the effect of time on GNN is modest. The reason may be that the time intervals in the data used in this study are relatively concentrated, as all the data come from the ICUs of a single medical center (Beth Israel Deaconess Medical Center).

Both the attention mechanism and the gate mechanism outperform the concatenation mechanism for combining the RNN and GNN representations. A plausible reason is that the two representations share some information while also being complementary to each other.
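Time volatility, as defined above, is the coefficient of variation of the time intervals. The minimal sketch below makes two assumptions that the text does not specify: the intervals of all patients in a diagnosis group are pooled, and the population standard deviation is used. The timestamps are hypothetical.

```python
import statistics
from typing import List, Sequence


def intervals(times_per_patient: Sequence[Sequence[float]]) -> List[float]:
    """Pool the gaps between consecutive time points of every patient."""
    return [b - a for times in times_per_patient for a, b in zip(times, times[1:])]


def time_volatility(gaps: Sequence[float]) -> float:
    """Standard deviation of the intervals divided by their mean."""
    mean = statistics.mean(gaps)
    if mean == 0:
        raise ValueError("mean interval must be non-zero")
    return statistics.pstdev(gaps) / mean


gaps = intervals([[0.0, 1.0, 4.0]])  # hypothetical time points in days: gaps [1.0, 3.0]
print(time_volatility(gaps))         # → 0.5
```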

Table 5 Effect of time on GNN

Although the proposed method shows promising performance, there are at least three directions for further improvement: (1) considering more types of medical information, such as the diagnosis sequence, besides the laboratory test sequence and the prescription sequence; (2) incorporating measured values into the binary vectors that represent the medical history at each time point; (3) integrating medical knowledge.

7 Conclusions

In this study, we propose a novel hybrid method of RNN and GNN, called RGNN, for next-period prescription prediction. RGNN represents patient longitudinal medical data from two views: RNN for patient status sequence representation and GNN for temporal medical event graph representation. Experiments on a benchmark dataset demonstrate the effectiveness of RGNN for next-period prescription prediction. RGNN could also be applied to other medical event prediction tasks.