Este es un ejemplo de «Hola mundo» de aprendizaje automático en Java. Simplemente le da una idea del aprendizaje automático en Java.
Ambiente
Java 1.6+ y Eclipse
Paso 1: descargue la biblioteca de Weka
Página de descarga: http://www.cs.waikato.ac.nz/ml/weka/snapshots/weka_snapshots.html
Descargue stable.XX.zip, descomprima el archivo, agregue weka.jar a la ruta de su biblioteca del proyecto Java en Eclipse.
Paso 2: preparar los datos
Cree un archivo txt «weather.txt» siguiendo el siguiente formato:
@relation weather @attribute outlook {sunny, overcast, rainy} @attribute temperature numeric @attribute humidity numeric @attribute windy {TRUE, FALSE} @attribute play {yes, no} @data sunny,85,85,FALSE,no sunny,80,90,TRUE,no overcast,83,86,FALSE,yes rainy,70,96,FALSE,yes rainy,68,80,FALSE,yes rainy,65,70,TRUE,no overcast,64,65,TRUE,yes sunny,72,95,FALSE,no sunny,69,70,FALSE,yes rainy,75,80,FALSE,yes sunny,75,70,TRUE,yes overcast,72,90,TRUE,yes overcast,81,75,FALSE,yes rainy,71,91,TRUE,no
Este conjunto de datos es del paquete de descarga de weka. Se encuentra en «/data/weather.numeric.arff». El nombre de la extensión del archivo es «arff», pero simplemente podemos usar «txt».
Paso 3: entrenamiento y pruebas con Weka
Este ejemplo de código utiliza un conjunto de clasificadores proporcionados por Weka. Entrena el modelo en el conjunto de datos dado y prueba utilizando una validación cruzada dividida en 10. Explicaré cada clasificador más adelante ya que es un tema más complicado.
import java.io.BufferedReader; import java.io.FileNotFoundException; import java.io.FileReader; import weka.classifiers.Classifier; import weka.classifiers.Evaluation; import weka.classifiers.evaluation.NominalPrediction; import weka.classifiers.rules.DecisionTable; import weka.classifiers.rules.PART; import weka.classifiers.trees.DecisionStump; import weka.classifiers.trees.J48; import weka.core.FastVector; import weka.core.Instances; public class WekaTest { public static BufferedReader readDataFile(String filename) { BufferedReader inputReader = null; try { inputReader = new BufferedReader(new FileReader(filename)); } catch (FileNotFoundException ex) { System.err.println("File not found: " + filename); } return inputReader; } public static Evaluation classify(Classifier model, Instances trainingSet, Instances testingSet) throws Exception { Evaluation evaluation = new Evaluation(trainingSet); model.buildClassifier(trainingSet); evaluation.evaluateModel(model, testingSet); return evaluation; } public static double calculateAccuracy(FastVector predictions) { double correct = 0; for (int i = 0; i < predictions.size(); i++) { NominalPrediction np = (NominalPrediction) predictions.elementAt(i); if (np.predicted() == np.actual()) { correct++; } } return 100 * correct / predictions.size(); } public static Instances[][] crossValidationSplit(Instances data, int numberOfFolds) { Instances[][] split = new Instances[2][numberOfFolds]; for (int i = 0; i < numberOfFolds; i++) { split[0][i] = data.trainCV(numberOfFolds, i); split[1][i] = data.testCV(numberOfFolds, i); } return split; } public static void main(String[] args) throws Exception { BufferedReader datafile = readDataFile("weather.txt"); Instances data = new Instances(datafile); data.setClassIndex(data.numAttributes() - 1); // Do 10-split cross validation Instances[][] split = crossValidationSplit(data, 10); // Separate split into training and testing arrays Instances[] trainingSplits = split[0]; Instances[] testingSplits = split[1]; // Use a set of classifiers Classifier[] models = { new J48(), // a decision tree new PART(), new DecisionTable(),//decision table majority classifier new DecisionStump() //one-level decision tree }; // Run for each model for (int j = 0; j < models.length; j++) { // Collect every group of predictions for current model in a FastVector FastVector predictions = new FastVector(); // For each training-testing split pair, train and test the classifier for (int i = 0; i < trainingSplits.length; i++) { Evaluation validation = classify(models[j], trainingSplits[i], testingSplits[i]); predictions.appendElements(validation.predictions()); // Uncomment to see the summary for each training-testing pair. //System.out.println(models[j].toString()); } // Calculate overall accuracy of current classifier on all splits double accuracy = calculateAccuracy(predictions); // Print current classifier's name and accuracy in a complicated, // but nice-looking way. System.out.println("Accuracy of " + models[j].getClass().getSimpleName() + ": " + String.format("%.2f%%", accuracy) + "n---------------------------------"); } } } |
La vista de paquete de su proyecto debería tener el siguiente aspecto:
Referencias:
1. http://www.cs.umb.edu/~ding/history/480_697_spring_2013/homework/WekaJavaAPITutorial.pdf
2. http://www.cs.ru.nl/P.Lucas/teaching/DM/weka.pdf