PyTorchメモ
機械学習の『学習』と『推論』
「機械学習」は、大量の学習データを機械に読み込ませ、そのデータを分析することで分類や識別のルールを作ろうというプログラム。
そのプロセスは、「学習」と「推論」の2つに分けられます。
学習はtrainingで推論はinference
IRってなんだ
PyTorch 2.0 offers two set of IRs for backends to interface with: Core Aten IR and Prims IR.
PyTorch 2.0 は、バックエンドがインターフェースするための IR の 2 つのセットを提供します: Core Aten IR および Prims IR。
- IRは、"Intermediate Representation"の略語です。これは、プログラミング言語やフレームワークなどが処理する中間表現形式を指します。 PyTorchの場合、Core ATen IRは、PyTorch Tensorの操作を表現するためのIRです。 Core ATen IRは、PyTorch内部で使用される中間表現であり、最適化やコンパイルなどの高度な機能を可能にするために使用されます。
@torch.no_gradの意味
torch.no_gradはテンソルの勾配の計算を不可にするContext-managerだ。
テンソルの勾配の計算を不可にすることでメモリの消費を減らす事が出来る。
このモデルでは、計算の結果毎にrequires_grad = Falseを持っている。
インプットがrequires_grad=Trueであろうとも。
このContext managerは、ローカルスレッドだ。他のスレッドの計算には影響を及ぼさない。
関数も@torch.no_grad()デコレーターを使用して返り値requires_grad=Falseに出来る。
つまり
with torch.no_grad():
のネストの中で定義した変数は、自動的にrequires_grad=Falseとなる。
なんでこういうことやるかと言うと
に書かれているように
意味としては、評価モード(Dropouts Layers、BatchNorm Layersをスキップ)に切り替えて、
自動微分を無効(勾配計算用パラメータを保存しないNoGrad Mode)にしてから実行することで不要な処理、
無駄なメモリ消費を抑えて推論を実行することができます。
機械学習モデルの学習時は勾配(こうばい)・・英語でgradient(グラディエント)。略してgrad(グラッド)を更新したいわけですが、学習が終わったらそれらは更新する必要がないわけです。
PyTorchでモデルの分散学習を行う方法
ChatGPTに教えてもらった内容なので、間違いあるかも
モデルの分散ストラテジーの選択
分散ストラテジーには、複数のGPUのどのようにモデルを分散させるかを定義するものがあります。PyTorchには、以下の分散ストラテジーがあります。
- DataParallel
- 1つのマシン内の複数のGPUでモデルを並列に処理するストラテジー
- DistributedDataParallel
- 複数のマシンにまたがって複数のGPUでモデルを並列に処理するストラテジー
- PipelineParallel
- モデルを複数のパーツに分割し、それぞれを異なるGPUで実行するストラテジー
お絵描きAIの推論実行だとDataParallelしか可用性がない気がする。以降はDataParallelの話しか書かない。
モデルの分散化
- 分散ストラテジーを選択したら、モデルを分散化する必要があります。
- DataParallelを使用する場合、モデルは単純にDataParallelのラッパーになります。
データの分散化
- データの分散化には、データを複数のGPUに分割し、それぞれのGPUで同じ演算を実行する必要があります
- これを実現するには、PyTorchのデータローダーを使用し、データを分割するためのデータ分割機能を有効にする必要があります
from torch.utils.data.distributed import DistributedSampler from torch.utils.data import DataLoader train_sampler = DistributedSampler(train_dataset) train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers, sampler=train_sampler)
学習と推論の並列化
一般的には並列化は学習時のものである
- 推論の並列化についてやる方法が見つかったので、 PyTorch(PiPPy) のページに記載