学習時の中間モデル保存時、
「{epoch:03d}-{val_loss:.5f}.hdf5」のようなデータがいっぱいでき上がってしまう。
ちょうど「Best3だけ残す」みたいな事ができればいいのに!
という事を実現する方法を考えました。
はじめに
Modelを学習をさせたときに、
以下のようにCheckpointがたまっていき、
容量くって仕方ないとか、スッキリしたいときに使えます。
「最新の3つだけ残したい」みたいな技です。
動作確認環境
Tensorflow 2.1 + Python3.6
(Windows 10 Home Anaconda環境)
TF2.0系なら動作すると思います。
(Kerasでも同様の手法で可能)
ソースコード一式
(2/29更新)
実際に動かした方が早い、ソース読んだ方が早いわーな方に。
Githubにアップしましたので参考までにどうぞ。
解決法 説明
次のCallback関数を作成し、学習時に呼び出すことで解決します。
- 「ModelCheckpoint」で生成されるファイル名と同一ファイル名を生成
- Checkpointの保存ファイル履歴をとっておく
- on_epoch_end 内で、履歴がN個を超えたら最古のファイルを削除
これを「CheckpointToolsクラス」と命名してコード化。
class CheckpointTools(Callback): def __init__(self, save_best_only=True, num_saves=3): self.last_val_loss = float("inf") # save_best_only判定用 self.save_best_only = save_best_only assert num_saves >= 1 self.num_saves = num_saves # 最大保存数(この数を超えたら最古を消す) self.recent_files = [] # ファイル履歴
また、”save_best_only”にも対応させます。
肝心の、ファイル履歴保存と、ファイル削除関数です。
def remove_oldest_file(self): if len(self.recent_files) > self.num_saves: file_name = self.recent_files.pop(0) # 先頭ファイルパス取得 if os.path.exists(file_name): os.remove(file_name) # ファイル削除 print('remove:'+file_name) # 毎epoch 終了時に呼び出されるCallback def on_epoch_end(self, epoch, logs={}): val_loss = logs['val_loss'] # ModelCheckpointのファイル名に合わせる # ※epoch=(epoch+1) に注意 file_name = os.path.join(CP_DIR, 'epoch{epoch:03d}-{val_loss:.5f}.hdf5').format(epoch=(epoch+1), val_loss=val_loss) print('store:'+file_name) if self.save_best_only: if val_loss < self.last_val_loss: self.last_val_loss = val_loss self.recent_files.append(file_name) self.remove_oldest_file() else: # ファイル履歴追加 self.recent_files.append(file_name) # 古いファイル削除 self.remove_oldest_file()
そして、ModelCheckpoint で生成するfilepath と、
on_epoch_endでのfile_nameをあわせます。
# Checkpoint作成設定 check_point = ModelCheckpoint(filepath = os.path.join(CP_DIR, 'epoch{epoch:03d}-{val_loss:.5f}.hdf5'), monitor='val_loss', verbose=1, save_best_only=SAVE_BEST_ONLY, mode='auto') cb_funcs.append(check_point) # 上で設定したCheckpointToolsをCallbackに組み込む cb_cptools = CheckpointTools(save_best_only=SAVE_BEST_ONLY, num_saves=3) cb_funcs.append(cb_cptools)
注意すべきは、
on_epoch_end で受け取れる”epoch”は、
「学習中のepoch数 – 1」が返される(1エポック目なら0)ため、
実際にはModelCheckpointでの”epoch”と1ずれています。
そのため、
on_epoch_end ではformat(epoch=(epoch+1)・・・ としています。
動作結果
上記例だと、以下のように3つ以外自動削除されます。
CheckpointTools にわたす num_saves を変更することで
残す数を指定できます。
さいごに
最後まで読んでいただき、ありがとうございます!
ブックマーク登録、
ツイッターフォロー、
よろしくお願いいたします!🙇♂️🙇♂️
↓↓↓