Welcome to the Extended Neural Network project! This repository demonstrates how we move beyond a simple autoencoder-based neural network for modeling student responses, by introducing cluster-based imputation of missing data using a question correlation matrix derived from subject metadata and TF-IDF similarity. Our improvements mitigate the effects of extremely sparse data and increase predictive accuracy for student performance.
- Overview
- Key Techniques
- Methodology
- Implementation
- Usage
- Performance and Results
- Limitations
- References and Acknowledgments
Motivation
Our dataset is extremely sparse: roughly
Contributions
- Label Shifting: We shift all incorrect answers from $0$ to $-1$, reserving $0$ for missing entries only. This lets the model distinguish truly incorrect answers ($-1$) from missing data ($0$).
- Question Correlation Matrix: We generate a matrix $C_Q$ that captures how similar or correlated two questions are, based on their subject metadata (via TF-IDF).
- K-Means Clustering: For each question cluster, we estimate a student's missing responses using a simple majority-voting (mean-based) approach within that cluster.
- Extended Autoencoder: Our base autoencoder is enhanced by training on these more meaningful imputed values and by incorporating an additional regularization term.
- Autoencoder
  - A two-layer linear network ($g$ and $h$) with sigmoid activations.
  - Learns a compressed ($k$-dimensional) representation of student responses and reconstructs the input vector.
- Imputation via Question Similarity
  - Subject correlation matrix $C_S$ built using TF-IDF on subject names.
  - Question correlation matrix $C_Q = A C_S A^{T}$, where $A$ is the question-subject assignment matrix.
  - Normalizing $C_Q$ ensures question similarities lie between $0$ and $1$.
- K-Means
  - We cluster questions into $k_{\text{means}}$ groups based on $C_Q^{\text{normalized}}$.
  - Missing responses for each student are filled using a mean-based decision from the questions in the same cluster.
- Regularization
  - We include a weight-decay penalty, weighted by $\lambda$, on the autoencoder's weights.
  - Improves generalization and stabilizes training.
- Load Sparse Matrix: We begin with a matrix $X \in \mathbb{R}^{N \times Q}$, where $N$ is the number of students and $Q$ is the number of questions. Initially, many entries are NaN.
- Label Shifting:
  - $-1$ for incorrect answers,
  - $0$ for missing answers,
  - $+1$ for correct answers.
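The label shift can be illustrated with a minimal NumPy sketch. It assumes the raw matrix stores $1$ for correct, $0$ for incorrect, and NaN for missing; the helper name `shift_labels` is hypothetical and not taken from the repository.

```python
import numpy as np


def shift_labels(X):
    """Map 1 (correct) -> +1, 0 (incorrect) -> -1, NaN (missing) -> 0.

    Hypothetical helper: assumes the raw encoding is {0, 1, NaN}.
    """
    X = np.asarray(X, dtype=float)
    shifted = np.where(X == 1, 1.0, -1.0)  # correct -> +1, everything else -> -1
    shifted[np.isnan(X)] = 0.0             # missing entries -> 0
    return shifted


# Toy matrix: 2 students x 3 questions.
X = np.array([[1.0, 0.0, np.nan],
              [np.nan, 1.0, 0.0]])
X_shifted = shift_labels(X)
```

After the shift, $0$ no longer carries two meanings: it marks only entries the student never answered.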
We use a simple autoencoder comprising:
- Encoder: $\mathbf{z} = \sigma \bigl(W_1 \cdot \mathbf{x} \bigr)$
- Decoder: $\hat{\mathbf{x}} = \sigma \bigl(W_2 \cdot \mathbf{z} \bigr)$

where $\sigma$ is the sigmoid function and $\mathbf{z}$ is the $k$-dimensional latent representation.
Given:
- Question-Subject Assignment Matrix $A \in \mathbb{R}^{Q \times S}$, with $A_{q,s} = 1$ if question $q$ is linked to subject $s$, else $0$.
- Subject Correlation Matrix $C_S \in \mathbb{R}^{S \times S}$, computed via TF-IDF similarity on subject names.

We define:

$$C_Q = A C_S A^{T}$$

We then normalize $C_Q$ so that question similarities lie between $0$ and $1$, giving $C_Q^{\text{normalized}}$.
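As a small worked example of $C_Q = A C_S A^{T}$, the sketch below uses toy values for $A$ and $C_S$. The normalization shown (dividing by the largest entry) is an assumption chosen only because it maps non-negative similarities into $[0, 1]$; the repository may normalize differently.

```python
import numpy as np

# Toy question-subject assignment matrix A (Q = 3 questions, S = 2 subjects).
A = np.array([[1, 0],
              [1, 1],
              [0, 1]], dtype=float)

# Toy subject similarity matrix C_S (symmetric, ones on the diagonal).
C_S = np.array([[1.0, 0.5],
                [0.5, 1.0]])

# Question correlation matrix, shape (Q, Q).
C_Q = A @ C_S @ A.T

# Assumed normalization: scale by the maximum entry so values lie in [0, 1].
C_Q_normalized = C_Q / C_Q.max()
```

Question 2 is linked to both subjects, so it has the largest self-similarity and sits at $1$ after scaling.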
To impute missing entries:
- Clustering: Each question is assigned to one cluster $\mathcal{C}_k$ by running K-Means on $C_Q^{\text{normalized}}$.
- Missing Entry Imputation: For student $s_m$ and question $q_n$ with a missing response ($X_{m,n} = \text{NaN}$):
  - Find the cluster $\mathcal{C}_k$ to which $q_n$ belongs.
  - Let $\mathcal{A}_k^{(m)}$ be the set of valid (non-missing) answers student $s_m$ has for the questions in $\mathcal{C}_k$.
  - If $\mathcal{A}_k^{(m)}$ is non-empty, compute its mean $\mu_m^{(k)}$. If $\mu_m^{(k)} > 0$, impute $+1$; otherwise impute $-1$. If $\mathcal{A}_k^{(m)}$ is empty, assign $0$.
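The imputation rule above can be sketched in NumPy as follows. This is an illustrative, vectorized version under the assumption that answers are already shifted to $\{-1, +1\}$ with NaN for missing; `impute_by_cluster` and the toy cluster labels are hypothetical names and values, not the repository's code.

```python
import numpy as np


def impute_by_cluster(X, labels):
    """Fill NaN entries of X (students x questions, values in {-1, +1, NaN})
    from each student's mean answer within the question's cluster."""
    X = np.asarray(X, dtype=float)
    labels = np.asarray(labels)
    filled = X.copy()
    for k in np.unique(labels):
        cols = np.where(labels == k)[0]        # questions in cluster k
        block = X[:, cols]
        observed = ~np.isnan(block)
        counts = observed.sum(axis=1)          # valid answers per student
        sums = np.where(observed, block, 0.0).sum(axis=1)
        # Mean over valid answers; left at 0 where a student has none.
        means = np.divide(sums, counts, out=np.zeros_like(sums), where=counts > 0)
        # mean > 0 -> +1, otherwise -1; no valid answers -> 0.
        fill = np.where(counts == 0, 0.0, np.where(means > 0, 1.0, -1.0))
        rows, local_cols = np.where(~observed)
        filled[rows, cols[local_cols]] = fill[rows]
    return filled


# Toy data: 3 students, 4 questions, two clusters of two questions each.
X = np.array([[1.0, np.nan, -1.0, np.nan],
              [np.nan, np.nan, 1.0, -1.0],
              [-1.0, np.nan, np.nan, 1.0]])
labels = np.array([0, 0, 1, 1])
X_filled = impute_by_cluster(X, labels)
```

Student 1 has no observed answers in cluster 0, so both of that student's entries there fall back to $0$, matching the empty-set case.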
Loss Function
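We minimize the reconstruction loss (sum of squared errors) plus a weight-decay term:

$$\mathcal{L} = \sum_{m=1}^{N} \lVert \mathbf{x}_m - \hat{\mathbf{x}}_m \rVert_2^2 + \frac{\lambda}{2} \bigl( \lVert W_1 \rVert_F^2 + \lVert W_2 \rVert_F^2 \bigr)$$

where $\mathbf{x}_m$ is the response vector of student $m$, $\hat{\mathbf{x}}_m$ is its reconstruction, and $\lVert \cdot \rVert_F$ is the Frobenius norm.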
We use Stochastic Gradient Descent (SGD) with a chosen learning rate (e.g., lr = 0.005) and run for num_epoch = 80 epochs.
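A minimal PyTorch sketch of one such update is shown below. It assumes a mask that restricts the squared error to observed entries and a factor of $\tfrac{1}{2}$ on the weight penalty; the function name `train_step`, the toy data, and the small `nn.Sequential` model are illustrative stand-ins rather than the repository's code.

```python
import torch
import torch.nn as nn


def train_step(model, optimizer, x, mask, lamb):
    """One SGD step: masked squared error plus a weight-decay penalty."""
    optimizer.zero_grad()
    x_hat = model(x)
    # Squared error counted only where mask == 1 (observed entries).
    recon = torch.sum(((x_hat - x) ** 2) * mask)
    # Squared Frobenius norms of the weight matrices (biases excluded here).
    penalty = sum(torch.sum(p ** 2)
                  for name, p in model.named_parameters()
                  if name.endswith("weight"))
    loss = recon + 0.5 * lamb * penalty  # the 0.5 factor is an assumption
    loss.backward()
    optimizer.step()
    return loss.item()


torch.manual_seed(0)
# Toy stand-in for the autoencoder: 6 questions, latent dimension 3.
model = nn.Sequential(nn.Linear(6, 3), nn.Sigmoid(), nn.Linear(3, 6), nn.Sigmoid())
optimizer = torch.optim.SGD(model.parameters(), lr=0.005)
x = torch.randint(0, 2, (4, 6)).float()  # toy responses for 4 students
mask = torch.ones_like(x)                # treat every entry as observed
losses = [train_step(model, optimizer, x, mask, lamb=0.001) for _ in range(80)]
```

Keeping the penalty inside the loss, rather than using the optimizer's `weight_decay` argument, makes the regularization term explicit and easy to swap for another penalty.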
Below are the two main scripts in this repository:
- File: `nn.py` (or a similar name)
- Purpose: Implements the autoencoder with zero-based imputation (filling missing entries with $0$) and trains across various latent dimensions $k$.

Key points:
- `AutoEncoder(nn.Module)`:
  - Forward uses two linear layers (`self.g`, `self.h`), each followed by a sigmoid.
  - The `train` function includes reconstruction loss on known entries plus a weight-norm penalty.
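```python
import torch
import torch.nn as nn


class AutoEncoder(nn.Module):
    def __init__(self, num_question, k=100):
        super(AutoEncoder, self).__init__()
        # Encoder g: question space -> k-dimensional latent space.
        self.g = nn.Linear(num_question, k)
        # Decoder h: latent space -> question space.
        self.h = nn.Linear(k, num_question)

    def forward(self, inputs):
        hidden = torch.sigmoid(self.g(inputs))
        out = torch.sigmoid(self.h(hidden))
        return out
```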
- File: `extended_nn.py` (or similar)
- Purpose: Improves data preprocessing by:
  - Computing the question correlation matrix $C_Q$ via TF-IDF-based subject similarity $C_S$.
  - Clustering questions with K-Means ($k_{\text{means}}=14$ by default).
  - Imputing missing entries in $\mathbf{X}$ based on cluster membership.
  - Training the same autoencoder architecture on the updated training matrix, optionally with an extra correlation-based regularization term.
Steps:
- `get_correlation_matrix(...)`
  - Loads subject metadata and computes $C_S$ via `cosine_similarity` on TF-IDF.
  - Returns $C_Q^{\text{normalized}}$.
- `load_data(k_mean=14, ...)`
  - Reads the question metadata and subject metadata, and merges them via an assignment matrix $A$.
  - Applies `KMeans` to cluster questions.
  - Imputes student responses.
- `AutoEncoder` remains structurally the same.
- The `train(...)` function includes the usual reconstruction loss and an optional decoder-correlation penalty.
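```python
# Decoder weight matrix, shape (num_question, k).
decoder_weights = model.h.weight
# trace(W^T C_Q W), clamped at zero so the penalty is never negative.
reg_term = torch.trace(
    torch.matmul(
        torch.matmul(decoder_weights.t(), C_Q_tensor),
        decoder_weights
    )
).clamp(min=0)
```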
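The preprocessing done by `get_correlation_matrix(...)` and `load_data(...)` could look roughly like the sketch below. It is not the repository's code: the subject names and assignment matrix are made-up toy data, the max-scaling normalization is an assumption, and using the rows of $C_Q^{\text{normalized}}$ as K-Means feature vectors is one plausible reading of "clustering based on $C_Q^{\text{normalized}}$".

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

# Hypothetical subject names (S = 4).
subject_names = ["Algebra", "Linear Algebra", "Geometry", "Plane Geometry"]

# Hypothetical question-subject assignment matrix A (Q = 6, S = 4).
A = np.array([[1, 0, 0, 0],
              [0, 1, 0, 0],
              [1, 1, 0, 0],
              [0, 0, 1, 0],
              [0, 0, 0, 1],
              [0, 0, 1, 1]], dtype=float)

# Subject similarity C_S: cosine similarity between TF-IDF vectors of the names.
tfidf = TfidfVectorizer().fit_transform(subject_names)
C_S = cosine_similarity(tfidf)             # shape (S, S)

# Question similarity C_Q = A C_S A^T, scaled into [0, 1] (assumed normalization).
C_Q = A @ C_S @ A.T
C_Q_normalized = C_Q / C_Q.max()

# Cluster questions, treating each row of C_Q_normalized as a feature vector.
kmeans = KMeans(n_clusters=2, n_init=10, random_state=0)
labels = kmeans.fit_predict(C_Q_normalized)
```

On this toy data the algebra questions and the geometry questions share no vocabulary, so $C_Q$ is block-diagonal and the two clusters line up with the two subject families. The script's default of `k_mean=14` would replace `n_clusters=2` on real data.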
- Install Dependencies
  - PyTorch
  - scikit-learn for KMeans
  - pandas for CSV loading
  - matplotlib & seaborn for visualization
- Prepare Data
  - Place your CSV files and sparse data files in a `./data/` directory or supply a custom path in the code.
- Train the Extended Model
  - Run `python extended_nn.py`.
  - Hyperparameters (learning rate, $\lambda$, `k_mean`) can be modified at the top of the script.
- Check Results
  - Training and validation losses/accuracies are printed each epoch.
  - Final test accuracy is displayed upon completion.
- Extended Model Accuracy: With our cluster-based imputation + label shifting + correlation-based regularization, we achieve a test accuracy of ~0.6997, outperforming the base neural network variants.
Comparative Table
Below is a summary of test accuracies for multiple approaches:
| Model | Test Accuracy |
|---|---|
| User-based kNN | 0.6890 |
| Item-based kNN | 0.6894 |
| IRT | 0.6994 |
| Base NN (no regularization) | 0.6808 |
| Base NN (with regularization) | 0.6861 |
| Extended NN (no regularization) | 0.6991 |
| Extended NN (with reg.) | 0.6997 |
- Subject/Question Similarity
  - Our method treats all questions linked to the same subjects as similar, ignoring differences in difficulty. A student might handle easier questions but fail advanced ones in the same subject cluster.
- Ignoring Additional Metadata
  - We do not incorporate extra student attributes (e.g., age, subscription status). Such features might further improve personalized predictions.
- High Computational Cost
  - With large-scale data, the cluster-based imputation can be expensive. We handle $\sim 900{,}000$ missing entries, which can be time-consuming to fill based on cluster membership.
- PyTorch for autoencoder implementation.
- scikit-learn for KMeans clustering and cosine similarity.
- pandas for data wrangling and CSV handling.
- matplotlib & seaborn for plotting metrics and heatmaps.
We hope this extended neural network approach, incorporating more intelligent imputation via question correlation, demonstrates a practical improvement over naive zero-based fill strategies in highly sparse educational data.
