This tutorial walks through training a NeoML classification model on the well-known News20 data set.
We are going to use a linear binary classifier combined with the "one versus all" method.
We assume that the data set is split into two parts: train and test, and each is serialized in a file on disk as a CMemoryProblem (which is a simple implementation of the IProblem interface provided in the library).
The library serialization methods can be used to load the data into memory for processing.
CPtr<CMemoryProblem> trainData = new CMemoryProblem();
CPtr<CMemoryProblem> testData = new CMemoryProblem();
CArchiveFile trainFile( "news20.train", CArchive::load );
CArchive trainArchive( &trainFile, CArchive::load );
trainArchive >> trainData;
CArchiveFile testFile( "news20.test", CArchive::load );
CArchive testArchive( &testFile, CArchive::load );
testArchive >> testData;
The "one versus all" classifier uses the specified binary classifier to train one model per class. Each model estimates the probability that an object belongs to its class. An input object is then classified by a vote among these models.
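To make the voting idea concrete, here is a minimal, library-independent sketch of the "one versus all" scheme. It is not NeoML's implementation: the names BinaryLogReg, TrainBinary, TrainOneVersusAll and ClassifyOneVersusAll are hypothetical, the binary model is a plain logistic regression trained by stochastic gradient descent on dense vectors, and the epoch count and learning rate are arbitrary assumptions chosen for a toy data set.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for one binary model: weights plus a bias term.
struct BinaryLogReg {
    std::vector<double> weights;
    double bias = 0.0;

    // Estimated probability that x belongs to the positive class.
    double Probability( const std::vector<double>& x ) const
    {
        double z = bias;
        for( std::size_t j = 0; j < weights.size(); j++ ) {
            z += weights[j] * x[j];
        }
        return 1.0 / ( 1.0 + std::exp( -z ) );
    }
};

// Trains one binary model by stochastic gradient descent on the logistic loss.
// labels[i] is 1 for the positive class and 0 for everything else.
// The epochs and rate defaults are assumptions, not tuned values.
BinaryLogReg TrainBinary( const std::vector<std::vector<double>>& x,
    const std::vector<int>& labels, int epochs = 500, double rate = 0.5 )
{
    assert( !x.empty() && x.size() == labels.size() );
    BinaryLogReg model;
    model.weights.assign( x[0].size(), 0.0 );
    for( int e = 0; e < epochs; e++ ) {
        for( std::size_t i = 0; i < x.size(); i++ ) {
            // Gradient of the logistic loss for one sample: (p - y) * x.
            const double error = model.Probability( x[i] ) - labels[i];
            for( std::size_t j = 0; j < model.weights.size(); j++ ) {
                model.weights[j] -= rate * error * x[i][j];
            }
            model.bias -= rate * error;
        }
    }
    return model;
}

// One-versus-all training: one binary model per class, where the objects of
// that class are positive and all other objects are negative.
std::vector<BinaryLogReg> TrainOneVersusAll( const std::vector<std::vector<double>>& x,
    const std::vector<int>& classes, int classCount )
{
    assert( x.size() == classes.size() );
    std::vector<BinaryLogReg> models;
    for( int c = 0; c < classCount; c++ ) {
        std::vector<int> binaryLabels( classes.size() );
        for( std::size_t i = 0; i < classes.size(); i++ ) {
            binaryLabels[i] = ( classes[i] == c ) ? 1 : 0;
        }
        models.push_back( TrainBinary( x, binaryLabels ) );
    }
    return models;
}

// Voting: the class whose model reports the highest probability wins.
int ClassifyOneVersusAll( const std::vector<BinaryLogReg>& models,
    const std::vector<double>& x )
{
    assert( !models.empty() );
    int best = 0;
    for( std::size_t c = 1; c < models.size(); c++ ) {
        if( models[c].Probability( x ) > models[best].Probability( x ) ) {
            best = static_cast<int>( c );
        }
    }
    return best;
}
```

Calling TrainOneVersusAll on a labeled set returns one model per class, and ClassifyOneVersusAll then picks the class with the most confident model. The library classes described next play the same roles for sparse News20 vectors.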
- Create a linear binary classifier using the CLinearBinaryClassifierBuilder class. Select the logistic regression loss function (EF_LogReg constant).
- Create a COneVersusAll classifier, passing the binary classifier set up in the previous step to the constructor.
- Call the Train method, passing the trainData training set prepared above. The method will train the model and return it as an object implementing the IModel interface.
CLinearBinaryClassifierBuilder linear( EF_LogReg );
COneVersusAll oneVersusAll( linear );
CPtr<IModel> model = oneVersusAll.Train( *trainData );
We can check how the trained model performs on the test sample using the Classify method of the IModel interface. Call this method for each vector of the testData data set prepared earlier.
int correct = 0;
for( int i = 0; i < testData->GetVectorCount(); i++ ) {
    CClassificationResult result;
    model->Classify( testData->GetVector( i ), result );
    if( result.PreferredClass == testData->GetClass( i ) ) {
        correct++;
    }
}
double totalResult = static_cast<double>(correct) / testData->GetVectorCount();
printf("%.3f\n", totalResult);
On this testing run, 83.3% of the vectors were classified correctly:
0.833
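The overall figure is the share of test vectors whose predicted class matches the expected one. As a sketch that does not depend on NeoML, the same computation can be written over two plain label arrays, together with a per-class breakdown that shows which classes the model handles poorly. The function names Accuracy and PerClassRecall are hypothetical, and the predicted labels are assumed to have been collected from PreferredClass in a loop like the one above.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Share of positions where the predicted label equals the expected one.
double Accuracy( const std::vector<int>& predicted, const std::vector<int>& expected )
{
    assert( predicted.size() == expected.size() && !expected.empty() );
    std::size_t correct = 0;
    for( std::size_t i = 0; i < expected.size(); i++ ) {
        if( predicted[i] == expected[i] ) {
            correct++;
        }
    }
    return static_cast<double>( correct ) / expected.size();
}

// For each class, the share of its objects that were classified correctly.
// Labels are assumed to lie in [0, classCount).
// A class with no objects in the expected labels gets 0.
std::vector<double> PerClassRecall( const std::vector<int>& predicted,
    const std::vector<int>& expected, int classCount )
{
    assert( predicted.size() == expected.size() );
    std::vector<std::size_t> total( classCount, 0 );
    std::vector<std::size_t> correct( classCount, 0 );
    for( std::size_t i = 0; i < expected.size(); i++ ) {
        total[expected[i]]++;
        if( predicted[i] == expected[i] ) {
            correct[expected[i]]++;
        }
    }
    std::vector<double> recall( classCount, 0.0 );
    for( int c = 0; c < classCount; c++ ) {
        if( total[c] > 0 ) {
            recall[c] = static_cast<double>( correct[c] ) / total[c];
        }
    }
    return recall;
}
```

With 20 newsgroups of similar topics, a per-class view like this can reveal that most errors come from a few closely related classes, which a single accuracy number hides.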