【TF2】Tensorflow2.x でGPU設定をする (旧 allow_growth の呼び替え)

TIPSディープラーニング
スポンサーリンク

よく使う手法なのですが、
TF2.0になり、随分変更があったのでメモに残しておきます。

調査日:2020年1月3日

概要

Tensorflowで、
GPUの使用するメモリを動的確保したり、
複数GPUマシン上の1つだけを指定するなどの
方法。

1.x → 2.0でも変更があり、
さらに2.0→2.1でも変更がありました。

https://github.com/tensorflow/tensorflow/blob/v2.1.0/RELEASE.md#breaking-changes

 

TF version 使用するGPUを指定
Tensorflow2.1 tf.config.set_visible_devices
Tensorflow2.0 tf.config.experimental.set_visible_devices
Tensorflow1.x tf.ConfigProto(tf.GPUOptions( visible_device_list=<GPU ID> ))

 

TF version 動的メモリアロケート設定
Tensorflow2.1 tf.config.experimental.set_memory_growth  2.0と同じ
Tensorflow2.0 tf.config.experimental.set_memory_growth
Tensorflow1.x tf.ConfigProto(tf.GPUOptions( allow_growth=True/False ))

 

 

私はコード共通化のため、
 if tf.__version__ >= “2.1.0”:
・・・
elif tf.__version__ >= “2.0.0”:
・・・
のように分岐して設定しています。

サンプルコード:

gpu_id = 0
print(tf.__version__)
if tf.__version__ >= "2.1.0":
    physical_devices = tf.config.list_physical_devices('GPU')
    tf.config.list_physical_devices('GPU')
    tf.config.set_visible_devices(physical_devices[gpu_id], 'GPU')
    tf.config.experimental.set_memory_growth(physical_devices[gpu_id], True)
elif tf.__version__ >= "2.0.0":
    #TF2.0
    physical_devices = tf.config.experimental.list_physical_devices('GPU')
    tf.config.experimental.set_visible_devices(physical_devices[gpu_id], 'GPU')
    tf.config.experimental.set_memory_growth(physical_devices[gpu_id], True)
else:
    from keras.backend.tensorflow_backend import set_session
    config = tf.ConfigProto(
        gpu_options=tf.GPUOptions(
            visible_device_list=str(gpu_id), # specify GPU number
            allow_growth=True
        )
    )
    set_session(tf.Session(config=config))

 

その他は公式API参考。

Module: tf.config  |  TensorFlow v2.15.0.post1
Public API for tf._api.v2.config namespace

 

ちなみに、
set_visible_devices がエラーを返したりする場合、
CUDAのインストールがうまく行っていない事が真っ先に考えられます。
(TF2.0はCUDA10.0、TF2.1からはCUDA10.1使用)

tf.config.list_physical_devices(‘GPU’) が空を返していないか、
確認してみるといいです。

 

お役に立てれば幸いです。
そいではー

タイトルとURLをコピーしました