本ページでは、R と Python それぞれで CART 法による決定木分析の手順について紹介します。
本例では、あくまで両環境でのコードや分析結果の表示についての比較の目的で行ったため、分析手法の説明や学習や検証における詳細なパラメータ設定は行っていないことをあらかじめご了承ください。
R での実装例
rpart パッケージを用いて実行できます。
ソースコード
1 2 3 4 5 6 7 8 9 |
library(rpart) data(iris) trainset <- iris[1:149,] testset <- iris[150,][-5] (dt.model <- rpart(Species ~ ., data=trainset)) predict(dt.model, testset, type="class") |
Python での実装例
scikit-learn の sklearn.tree.DecisionTreeClassifier
クラスを用いて実行できます。
ソースコード
1 2 3 4 5 6 7 8 9 10 |
from sklearn.datasets import load_iris from sklearn.tree import DecisionTreeClassifier iris = load_iris() X, y = iris.data[:-1], iris.target[:-1] dt = DecisionTreeClassifier() dt.fit(X, y) print(dt.predict([iris.data[-1]])) |
出力結果
“2” と出力されているのは “Virginica” なので、正しく推定できたことがわかります。