PythonでPostgresデータから決定木を構築する
「HackerNews翻訳してみた」が POSTD (ポスト・ディー) としてリニューアルしました!この記事はここでも公開されています。
Original article: Building a Decision Tree in Python from Postgres data by Gary Sieling
今回は、任意の人物の所得を人口統計データを使って予測する手法をご紹介します。使用するのは20年前の人口統計データです。
この例を用いて、関係データベースの情報から予測モデルを導き出す方法と、その途中で起こり得るトラブルについて触れたいと思います。
このデータの優れた点は、データの作成者が下記のようなアルゴリズムの精度をデータに添付している点です。こうした数値はスモークテストの結果評価に役立ちます。
Algorithm Error -- ---------------- ----- 1 C4.5 15.54 2 C4.5-auto 14.46 3 C4.5 rules 14.94 4 Voted ID3 (0.6) 15.64 5 Voted ID3 (0.8) 16.47 6 T2 16.84 7 1R 19.54 8 NBTree 14.10 9 CN2 16.00 10 HOODG 14.82 11 FSS Naive Bayes 14.05 12 IDTM (Decision table) 14.46 13 Naive-Bayes 16.12 14 Nearest-neighbor (1) 21.42 15 Nearest-neighbor (3) 20.35 16 OC1 15.04
このデータセットをPostgresのデータベースに読み込むには、まずデータ作成者のファイルの一番下にある空白行と「1×0 Cross Validator」と書かれた行を削除する必要があります。
次に、下記の手法でデータを読み込みます。テストデータも同じPostgresデータベースにロードしていますが、そこは気にしないでください。
ご覧の通り、PostgresではUNCパスを指定することができます。特筆するほどのことではないと思われるかもしれませんが、VisualStudioではUNCパスを読み込むことができないことを考えると、うれしい機能ではないでしょうか。
DROP TABLE income_trn; CREATE TABLE income_trn (age INTEGER, workclass text, fnlwgt INTEGER, education text, education_num INTEGER, marital_status text, occupation text, relationship text, race text, sex text, capital_gain INTEGER, capital_loss INTEGER, hours_per_week INTEGER, native_country text, category text); COPY income_trn FROM '\\\\nas\\Files\\Data\\income\\adult.data' DELIMITER ',' CSV; DROP TABLE income_test; CREATE TABLE income_test (age INTEGER, workclass text, fnlwgt INTEGER, education text, education_num INTEGER, marital_status text, occupation text, relationship text, race text, sex text, capital_gain INTEGER, capital_loss INTEGER, hours_per_week INTEGER, native_country text, category text); COPY income_test FROM '\\\\nas\\Files\\Data\\income\\adult.test' DELIMITER ',' CSV;
Pythonの場合、こうしたデータならSQLAlchemyを使っても読み込むことができます。ただしPostgresドライバ("pg8000")は不安定なため、たまに次のようなエラーが起こる場合があります。
ProgrammingError: (ProgrammingError) ('ERROR', '34000', 'portal "pg8000_portal_12" does not exist') None None
エラーの原因はさまざまですが、そのひとつとしてPostgresの旧バージョンを使用しているケースが考えられます (著者は9.3を使用しました)。旧バージョンには閉じているカーソルのデータが読み込まれてしまうという問題もあるようです。
from sqlalchemy import * engine = create_engine( "postgresql+pg8000://postgres:postgres@localhost/pacer", isolation_level="READ UNCOMMITTED" ) c = engine.connect() meta = MetaData() income_trn = Table('income_trn', meta, autoload=True, autoload_with=engine) income_test = Table('income_test', meta, autoload=True, autoload_with=engine)
大量のデータを処理する場合、クエリの結果をモデルの中にストリーム処理する手法はとても有効です。ただし今回はデータサイズが小さいので、ストリーム処理は行いませんでした。また、ひとつの表に全データが入っている場合には、効率よくデータ処理ができるよう、任意にデータを半分に分割する方法を考えなければならないでしょう。
from sqlalchemy.sql import select def get_data(table): s = select([table]) result = c.execute(s) return [row for row in result] test_data = get_data(income_trn) trn_data = get_data(income_test)
このデータには、もともとテキスト型の列と整数型の列が混ざって入っていました(職業と年齢など)。意外にもPythonの機械学習ライブラリは、このような混合データの認識が苦手なようです(少なくともディシジョンツリーは苦手です)。こうしたデータは一連のvalue値のみで構成されたデータとはまったく別物なので、特別な配慮が必要になります。
問題は、ライブラリがvalue値のリストを期待しているにも関わらず元データが数値型である、というケースです。この問題を解決するには、次のようなマッピングを行うグローバルな辞書の構築が必要になるでしょう。
maxVal = 0 vals = dict() rev_vals = dict() def f(x): global maxVal global vals if (not x in vals): maxVal = maxVal + 1 vals[x] = maxVal rev_vals[maxVal] = x return vals[x]
ここで、属性を2つに分割しなければなりません。ひとつは出力に、もうひとつは出力を予測するための属性に分割します。
def get_selectors(data): return [ [f(x) for x in t[0:-1]] for t in data] def get_predictors(data): return [0 if "<" in t[14] else 1 for t in data] trn = get_selectors(trn_data) trn_v = get_predictors(trn_data)
この事例で最も注目すべきは、なんといってもモデルの作成が驚くほど簡単なことです。例を見てみましょう。
from sklearn import tree clf = tree.DecisionTreeRegressor() clf = clf.fit(trn, trn_v)
結局、テストメソッドは自前で実装することになりました。混同行列は、クラスに定義されたデータの計算が得意ではないようですね。
test = get_selectors(test_data) test_v = get_predictors(test_data) testsRun = 0 testsPassed = 0 for t in test: if clf.predict(t) == test_v[testsRun]: testsPassed = testsPassed + 1 testsRun = testsRun + 1 100 * testsPassed / testsRun DecisionTreeClassifier: 78% DecisionTreeRegressor: 79%
最後に、scikit-learnのドキュメントをチェックしてみてください。すべての事例にステキな図表がついていますね。ただ、ドキュメントを読めば分かりますが、ディシジョンツリーはかなり長くなる可能性があります。何千というルールが適用されることもあるので、よほどシンプルなケースでない限り図表化には向かないでしょう。