【TF2】Tensorflow学習時に、増え続けるCheckpointデータを自動削除する方法

TIPSディープラーニング
スポンサーリンク

学習時の中間モデル保存時、
「{epoch:03d}-{val_loss:.5f}.hdf5」のようなデータがいっぱいでき上がってしまう。
ちょうど「Best3だけ残す」みたいな事ができればいいのに!
という事を実現する方法を考えました。

はじめに

Modelを学習をさせたときに、
以下のようにCheckpointがたまっていき、
容量くって仕方ないとか、スッキリしたいときに使えます。
「最新の3つだけ残したい」みたいな技です。

動作確認環境

Tensorflow 2.1 + Python3.6
(Windows 10 Home Anaconda環境)

TF2.0系なら動作すると思います。
(Kerasでも同様の手法で可能)

ソースコード一式

(2/29更新)
実際に動かした方が早い、ソース読んだ方が早いわーな方に。
Githubにアップしましたので参考までにどうぞ。

GitHub - MaxiParadise/CheckpointAutoRemove
Contribute to MaxiParadise/CheckpointAutoRemove development by creating an account on GitHub.

 

解決法 説明

次のCallback関数を作成し、学習時に呼び出すことで解決します。

  1. 「ModelCheckpoint」で生成されるファイル名と同一ファイル名を生成
  2. Checkpointの保存ファイル履歴をとっておく
  3. on_epoch_end 内で、履歴がN個を超えたら最古のファイルを削除

 

これを「CheckpointToolsクラス」と命名してコード化。

class CheckpointTools(Callback):
    def __init__(self, save_best_only=True, num_saves=3):
        self.last_val_loss = float("inf")    # save_best_only判定用
        self.save_best_only = save_best_only
        assert num_saves >= 1
        self.num_saves = num_saves    # 最大保存数(この数を超えたら最古を消す)
        self.recent_files = []        # ファイル履歴

また、”save_best_only”にも対応させます。

肝心の、ファイル履歴保存と、ファイル削除関数です。

    def remove_oldest_file(self):
        if len(self.recent_files) > self.num_saves:
            file_name = self.recent_files.pop(0)  # 先頭ファイルパス取得
            if os.path.exists(file_name):
                os.remove(file_name)              # ファイル削除
            print('remove:'+file_name)

    # 毎epoch 終了時に呼び出されるCallback
    def on_epoch_end(self, epoch, logs={}):
        val_loss = logs['val_loss']

        # ModelCheckpointのファイル名に合わせる
        # ※epoch=(epoch+1) に注意
        file_name = os.path.join(CP_DIR, 'epoch{epoch:03d}-{val_loss:.5f}.hdf5').format(epoch=(epoch+1), val_loss=val_loss)
        print('store:'+file_name)

        if self.save_best_only:
            if val_loss < self.last_val_loss:
                self.last_val_loss = val_loss
                self.recent_files.append(file_name)
                self.remove_oldest_file()
        else:
            # ファイル履歴追加
            self.recent_files.append(file_name)
            # 古いファイル削除
            self.remove_oldest_file()

そして、ModelCheckpoint で生成するfilepath と、
on_epoch_endでのfile_nameをあわせます。

# Checkpoint作成設定
check_point = ModelCheckpoint(filepath = os.path.join(CP_DIR, 'epoch{epoch:03d}-{val_loss:.5f}.hdf5'), monitor='val_loss', verbose=1, save_best_only=SAVE_BEST_ONLY, mode='auto')
cb_funcs.append(check_point)

# 上で設定したCheckpointToolsをCallbackに組み込む
cb_cptools = CheckpointTools(save_best_only=SAVE_BEST_ONLY, num_saves=3)
cb_funcs.append(cb_cptools)

 

注意すべきは、
on_epoch_end で受け取れる”epoch”は、
「学習中のepoch数 – 1」が返される(1エポック目なら0)ため、

実際にはModelCheckpointでの”epoch”と1ずれています。

そのため、
on_epoch_end ではformat(epoch=(epoch+1)・・・ としています。

動作結果

上記例だと、以下のように3つ以外自動削除されます。
CheckpointTools にわたす num_saves を変更することで
残す数を指定できます。

 

さいごに

最後まで読んでいただき、ありがとうございます!

ブックマーク登録、
ツイッターフォロー、
よろしくお願いいたします!🙇‍♂️🙇‍♂️
↓↓↓

タイトルとURLをコピーしました