【Unity】ML.NETの機械学習モデルで画像を分類する方法を紹介します!
ML.NETの機械学習モデルで画像を分類する方法を紹介します。今回は動物のイラスト画像をML.NETの機械学習モデルで分類します。任意の画像が何に分類されるか、複数の動物がいたらどうなるのか、さらにはAIっぽくリアルタイム学習して間違った評価を正しい評価へ学習する方法などを紹介したいと思います。
開発環境は以下のリンクで紹介したものを再利用しますので一度ご覧下さい。
画像サンプルの用意
任意のフォルダー以下に分類したいラベル名のフォルダーを用意します。今回は「魚」「犬」「鳥」「猫」に分類します。
			各フォルダーになるべく多くのサンプル画像を追加します。ML.NETの画像分類の機械学習モデルはJPEGとPNGしか対応していないので注意して下さい。
			
			
			
			以上でサンプル画像の準備は終了です。
モデルビルダーで機械学習モデルを追加
C#クラスライブラリプロジェクトに機械学習モデルを追加します。
			シナリオの選択で「画像の分類」を選択します。
			トレーニング環境の選択でローカル(CPU)を選択します。
			データの追加画面でサンプル画像フォルダーを選択します。
			トレーニング画面でトレーニングを開始して下さい。
			適当な画像を評価して結果を確認して下さい。
			使用タブでコードスニペットをコピーします。
			最後にx64でReleaseビルドしてDLLと機械学習データ(ImageModel.mlnet)が生成されたことを確認して下さい。
Unityへ組み込んで実行
Unityのプロジェクトルートフォルダーに学習データ(ImageModel.mlnet)をコピーして、ビルドしたフォルダー内の全てのDLLをAssetsフォルダー以下にコピーして下さい。
続いてスクリプトImageClassification.csを作成します。Startメソッドにコピーしておいたコードスニペットをペーストしてresult.PredictedLabelをログ表示します。
using MyMLApp;
using System.IO;
using UnityEngine;
public class ImageClassification : MonoBehaviour
{
    void Start()
    {
        //Load sample data
        var imageBytes = File.ReadAllBytes(@"W:\サンプル画像\犬\dog_afghan_hound.png");
        ImageModel.ModelInput sampleData = new ImageModel.ModelInput()
        {
            ImageSource = imageBytes,
        };
        //Load model and predict output
        var result = ImageModel.Predict(sampleData);
        Debug.Log(result.PredictedLabel);
    }
}
			犬のサンプル画像はこちらです。
			Unityを実行してコンソールに「犬」と表示されることを確認して下さい。
		複数の候補を取得する方法
ImageModel.Predict関数ではベストな1つの候補を取得できましたが、複数の動物がいる画像から複数の候補を得たい場合はどうすれば良いのでしょうか。例えば猫が魚をくわえて走っている画像で、猫と魚という答えが欲しい場合です。
			ImageModel.PredictAllLabels関数を利用するとすべてのラベルの評価情報がソートされた状態で得られます。
using MyMLApp;
using System.IO;
using UnityEngine;
public class ImageClassification : MonoBehaviour
{
    void Start()
    {
        //Load sample data
        var imageBytes = File.ReadAllBytes(@"W:\cat_fish_run.png");
        ImageModel.ModelInput sampleData = new ImageModel.ModelInput()
        {
            ImageSource = imageBytes,
        };
        //Load model and predict output
        var labels = ImageModel.PredictAllLabels(sampleData);
        foreach(var label in labels)
        {
            Debug.Log(label.Key + " " + (int)(100 * label.Value) + "%");
        }
    }
}
			このスクリプトを実行すると魚が87%、猫が11%と上位の評価を得ることができました。
			続いて猫と犬が一緒にいる画像はどうでしょうか。
			猫が68%、犬が28%と上位の評価になりました。
			複数の候補を取得したい場合はImageModel.PredictAllLabels関数で上位候補を利用しましょう。
リアルタイム学習する方法
ここまで画像の分類について紹介してきましたが全て期待した評価を得ることができました。しかしながら、間違った評価をされた場合はどうすれば良いでしょうか。このセクションでは間違った評価をリアルタイム学習して正しい評価を導き出す方法について説明します。ML.NETのモデルビルダーで生成されたデータとソースコードだけでは不十分なので自力で実装していきます。
今回利用する変数について簡単に説明します。MLContextは全ての共通コンテキストで、様々な機能を提供します。model変数は「機械学習モデル」のモデルで、学習したり評価したりするときに使われる変数です。DataViewSchemaはデータ構造を表すスキーマで、学習データを保存するときに使います。PredictionEngineは評価を担当するクラスです。
MLContext context;
ITransformer model;
DataViewSchema schema;
PredictionEngine<ImageModel.ModelInput, ImageModel.ModelOutput> engine;
			最初に機械学習するTrain関数を実装します。LoadImagesFromDirectory関数はサンプル画像が収められてるルートディレクトリ以下の全ての画像からラベルと画像データを持った入力モデルデータを返します。
void Train()
{
    context = new MLContext();
    var images = LoadImagesFromDirectory(SampleRootDirectory);
    IDataView imageData = context.Data.LoadFromEnumerable(images);
    schema = imageData.Schema;
    model = ImageModel.RetrainModel(context, imageData);
    engine = context.Model.CreatePredictionEngine<ImageModel.ModelInput, ImageModel.ModelOutput>(model);
}
IEnumerable<ImageModel.ModelInput> LoadImagesFromDirectory(string directory)
{
    // directory以下全てのファイルを列挙
    var files = Directory.GetFiles(directory, "*", SearchOption.AllDirectories);
    foreach (var file in files)
    {
        // JPEGかPNGのみ
        if (file.EndsWith(".jpg") || file.EndsWith(".png"))
        {
            // 親ディレクトリ名がラベル名
            var label = Directory.GetParent(file).Name;
            yield return new ImageModel.ModelInput
            {
                Label = label,
                ImageSource = File.ReadAllBytes(file)
            };
        }
    }
}
			続いて学習データを保存するSave関数を実装します。
void Save()
{
    context.Model.Save(model, schema, MachineLearningFile);
}
			最後に学習データを読み込むLoad関数を実装します。
void Load()
{
    context = new MLContext();
    model = context.Model.Load(MachineLearningFile, out schema);
    engine = context.Model.CreatePredictionEngine<ImageModel.ModelInput, ImageModel.ModelOutput>(model);
}
			以上でリアルタイム学習するのに必要な機能を実装することができました。
リアルタイム学習のデモプログラム
任意の画像を読み込んで誤った評価をされたものをリアルタイム学習して正しい評価にするデモプログラムを作成します。プロジェクトのルートフォルダーにSampleDataというフォルダーを作ってその下に各ラベル名のフォルダー「魚、犬、鳥、猫」にサンプル画像を配置します。間違った評価結果を得やすくするために各フォルダー5枚のみのサンプル画像にしました。
			①学習ボタンで学習データを作成します。
②フィールドに任意の画像パスを入力して読み込みボタンで画像を読み込みます。
③評価ボタンで評価します。鳥と判断されました。
④鳥は誤った評価なので「犬に分類」ボタンを押して犬だと教えます。(犬フォルダに画像をコピーしてます)
⑤学習ボタンで再学習します。
⑥評価ボタンで再評価します。
			⑦犬と再評価されたのでセーブボタンで学習データを保存します。
ここでUnityを再実行します。
			⑧ロードボタンで保存した学習データを読み込みます。
⑨フィールドに先程の画像パスを入力して読み込みボタンで画像を読み込みます。
⑩評価ボタンを押して犬になることを確認します。
作成したリアルタイム学習するスクリプトは以下の通りです。
RealtimeTraining.cs
using Microsoft.ML;
using MyMLApp;
using System.Collections.Generic;
using System.IO;
using UnityEngine;
public class RealtimeTraining : MonoBehaviour
{
    const string MachineLearningFile = "RealtimeTraining.mlnet";
    const string SampleRootDirectory = "SampleData/";
    MLContext context;
    ITransformer model;
    DataViewSchema schema;
    PredictionEngine<ImageModel.ModelInput, ImageModel.ModelOutput> engine;
    string path;
    byte[] data;
    Texture2D texture;
    string result;
    private void OnGUI()
    {
        path = GUI.TextField(new Rect(0, 0, 300, 20), path);
        if (GUI.Button(new Rect(300, 0, 100, 20), "読み込み"))
        {
            if(File.Exists(path))
            {
                data = File.ReadAllBytes(path);
                texture = new Texture2D(1, 1);
                texture.LoadImage(data);
                texture.Apply();
            }
        }
        if(texture)
        {
            GUI.DrawTexture(new Rect(0, 20, 400, 400), texture);
        }
        if(GUI.Button(new Rect(0, 420, 100, 20), "ロード"))
        {
            Load();
        }
        if (GUI.Button(new Rect(100, 420, 100, 20), "セーブ"))
        {
            Save();
        }
        if (GUI.Button(new Rect(200, 420, 100, 20), "学習"))
        {
            Train();
        }
        if (GUI.Button(new Rect(300, 420, 100, 20), "評価"))
        {
            ImageModel.ModelInput sampleData = new ImageModel.ModelInput()
            {
                ImageSource = data,
            };
            result = engine.Predict(sampleData).PredictedLabel;
        }
        if(!string.IsNullOrEmpty(result))
        {
            GUI.Label(new Rect(0, 440, 400, 20), "この画像は " + result + " です");
            GUI.Label(new Rect(0, 460, 120, 20), "間違ってますか?");
            if(GUI.Button(new Rect(120, 460, 80, 20), "魚に分類"))
            {
                File.Copy(path, SampleRootDirectory + "魚/" + Path.GetFileName(path));
            }
            if(GUI.Button(new Rect(200, 460, 80, 20), "犬に分類"))
            {
                File.Copy(path, SampleRootDirectory + "犬/" + Path.GetFileName(path));
            }
            if(GUI.Button(new Rect(280, 460, 80, 20), "鳥に分類"))
            {
                File.Copy(path, SampleRootDirectory + "鳥/" + Path.GetFileName(path));
            }
            if(GUI.Button(new Rect(360, 460, 80, 20), "猫に分類"))
            {
                File.Copy(path, SampleRootDirectory + "猫/" + Path.GetFileName(path));
            }
        }
    }
    void Train()
    {
        context = new MLContext();
        var images = LoadImagesFromDirectory(SampleRootDirectory);
        IDataView imageData = context.Data.LoadFromEnumerable(images);
        schema = imageData.Schema;
        model = ImageModel.RetrainModel(context, imageData);
        engine = context.Model.CreatePredictionEngine<ImageModel.ModelInput, ImageModel.ModelOutput>(model);
    }
    void Load()
    {
        context = new MLContext();
        model = context.Model.Load(MachineLearningFile, out schema);
        engine = context.Model.CreatePredictionEngine<ImageModel.ModelInput, ImageModel.ModelOutput>(model);
    }
    void Save()
    {
        context.Model.Save(model, schema, MachineLearningFile);
    }
    IEnumerable<ImageModel.ModelInput> LoadImagesFromDirectory(string directory)
    {
        // directory以下全てのファイルを列挙
        var files = Directory.GetFiles(directory, "*", SearchOption.AllDirectories);
        foreach (var file in files)
        {
            // JPEGかPNGのみ
            if (file.EndsWith(".jpg") || file.EndsWith(".png"))
            {
                // 親ディレクトリ名がラベル名
                var label = Directory.GetParent(file).Name;
                yield return new ImageModel.ModelInput
                {
                    Label = label,
                    ImageSource = File.ReadAllBytes(file)
                };
            }
        }
    }
}
		ML.NETの機械学習モデルで画像を分類する方法を紹介しました。
画像を利用した面白いリアルタイムコンテンツを作れそうですね。
				学習すればするほど賢くなっていくのも機械学習の魅力ですね。
関連ページ
こちらのページも合わせてご覧下さい。