
学習時の中間モデル保存時、
「{epoch:03d}-{val_loss:.5f}.hdf5」のようなデータがいっぱいでき上がってしまう。
ちょうど「Best3だけ残す」みたいな事ができればいいのに!
という事を実現する方法を考えました。
はじめに
Modelを学習をさせたときに、
以下のようにCheckpointがたまっていき、
容量くって仕方ないとか、スッキリしたいときに使えます。
「最新の3つだけ残したい」みたいな技です。
動作確認環境
Tensorflow 2.1 + Python3.6
(Windows 10 Home Anaconda環境)
TF2.0系なら動作すると思います。
(Kerasでも同様の手法で可能)
ソースコード一式
(2/29更新)
実際に動かした方が早い、ソース読んだ方が早いわーな方に。
Githubにアップしましたので参考までにどうぞ。
解決法 説明
次のCallback関数を作成し、学習時に呼び出すことで解決します。
- 「ModelCheckpoint」で生成されるファイル名と同一ファイル名を生成
- Checkpointの保存ファイル履歴をとっておく
- on_epoch_end 内で、履歴がN個を超えたら最古のファイルを削除
これを「CheckpointToolsクラス」と命名してコード化。
class CheckpointTools(Callback):
def __init__(self, save_best_only=True, num_saves=3):
self.last_val_loss = float("inf") # save_best_only判定用
self.save_best_only = save_best_only
assert num_saves >= 1
self.num_saves = num_saves # 最大保存数(この数を超えたら最古を消す)
self.recent_files = [] # ファイル履歴
また、”save_best_only”にも対応させます。
肝心の、ファイル履歴保存と、ファイル削除関数です。
def remove_oldest_file(self):
if len(self.recent_files) > self.num_saves:
file_name = self.recent_files.pop(0) # 先頭ファイルパス取得
if os.path.exists(file_name):
os.remove(file_name) # ファイル削除
print('remove:'+file_name)
# 毎epoch 終了時に呼び出されるCallback
def on_epoch_end(self, epoch, logs={}):
val_loss = logs['val_loss']
# ModelCheckpointのファイル名に合わせる
# ※epoch=(epoch+1) に注意
file_name = os.path.join(CP_DIR, 'epoch{epoch:03d}-{val_loss:.5f}.hdf5').format(epoch=(epoch+1), val_loss=val_loss)
print('store:'+file_name)
if self.save_best_only:
if val_loss < self.last_val_loss:
self.last_val_loss = val_loss
self.recent_files.append(file_name)
self.remove_oldest_file()
else:
# ファイル履歴追加
self.recent_files.append(file_name)
# 古いファイル削除
self.remove_oldest_file()
そして、ModelCheckpoint で生成するfilepath と、
on_epoch_endでのfile_nameをあわせます。
# Checkpoint作成設定
check_point = ModelCheckpoint(filepath = os.path.join(CP_DIR, 'epoch{epoch:03d}-{val_loss:.5f}.hdf5'), monitor='val_loss', verbose=1, save_best_only=SAVE_BEST_ONLY, mode='auto')
cb_funcs.append(check_point)
# 上で設定したCheckpointToolsをCallbackに組み込む
cb_cptools = CheckpointTools(save_best_only=SAVE_BEST_ONLY, num_saves=3)
cb_funcs.append(cb_cptools)
注意すべきは、
on_epoch_end で受け取れる”epoch”は、
「学習中のepoch数 – 1」が返される(1エポック目なら0)ため、
実際にはModelCheckpointでの”epoch”と1ずれています。
そのため、
on_epoch_end ではformat(epoch=(epoch+1)・・・ としています。
動作結果
上記例だと、以下のように3つ以外自動削除されます。
CheckpointTools にわたす num_saves を変更することで
残す数を指定できます。
さいごに
最後まで読んでいただき、ありがとうございます!
ブックマーク登録、
ツイッターフォロー、
よろしくお願いいたします!🙇♂️🙇♂️
↓↓↓




