Pytorch – numpy、Pytroch の関数対応表

Pytorch – numpy、Pytroch の関数対応表

概要

numpy ユーザー向けに numpy の関数と Pytorch の関数の対応表をまとめました。

四則演算

比較

剰余、累乗、絶対値

ブール演算

ビット演算

最大、最小

numpy torch torch.Tensor torch.Tensor (inplace)
最大 numpy.amax torch.max Tensor.max
最小 numpy.amin torch.min Tensor.min
最大のインデックス numpy.argmax torch.argmax Tensor.argmax
最小のインデックス numpy.argmin torch.argmin Tensor.argmin
要素ごとの最小 numpy.minimum torch.max Tensor.max
要素ごとの最大 numpy.maximum torch.min Tensor.min

総和、総乗、累積和、累積積

arange、linspace、logspace、meshgrid

numpy torch torch.Tensor torch.Tensor (inplace)
arange numpy.arange torch.arange
linspace numpy.linspace torch.linspace
logspace numpy.logspace torch.logspace
meshgrid numpy.meshgrid torch.meshgrid

三角関数

指数関数、対数関数

その他の関数

numpy torch torch.Tensor torch.Tensor (inplace)
床関数 numpy.floor torch.floor Tensor.floor Tensor.floor_
天井関数 numpy.ceil torch.ceil Tensor.ceil Tensor.ceil_
クリップ numpy.clip torch.clamp Tensor.clamp Tensor.clamp_
符号関数 numpy.sign torch.sign Tensor.sign Tensor.sign_
度 → ラジアン numpy.deg2rad torch.deg2rad
度 → ラジアン numpy.degrees torch.deg2rad
ラジアン → 度 numpy.rad2deg torch.rad2deg
ラジアン → 度 numpy.radians torch.rad2deg
台形積分 numpy.trapz torch.trapz

empty、zeros、ones、full、単位行列

ドット積、クロス積、行列積、テンソル積

numpy torch torch.Tensor torch.Tensor (inplace)
ドット積 numpy.dot torch.dot Tensor.dot
クロス積 numpy.cross torch.cross Tensor.cross
行列積 numpy.matmul torch.matmul Tensor.matmul
テンソル積 numpy.tensordot torch.tensordot

反転、90度回転、シフト、リサイズ

numpy torch torch.Tensor torch.Tensor (inplace)
反転 numpy.flip torch.flip Tensor.flip
90度回転 numpy.rot90 torch.rot90 Tensor.rot90
シフト numpy.roll torch.roll Tensor.roll
リサイズ numpy.resize Tensor.resize_

線形代数

複製、ブロードキャスト

numpy torch torch.Tensor torch.Tensor (inplace)
繰り返す numpy.repeat torch.repeat_interleave Tensor.repeat_interleave
繰り返す numpy.tile Tensor.repeat
ブロードキャスト numpy.broadcast torch.broadcast_tensors
ブロードキャスト numpy.broadcast_arrays torch.broadcast_tensors
ブロードキャスト Tensor.expand
ブロードキャスト Tensor.expand_as

丸め

numpy torch torch.Tensor torch.Tensor (inplace)
最近接丸め numpy.around torch.round Tensor.round Tensor.round_
0に近いほうに丸め numpy.trunc torch.trunc Tensor.trunc Tensor.trunc_

ソート

numpy torch torch.Tensor torch.Tensor (inplace)
ソート numpy.sort torch.sort Tensor.sort
ソート numpy.argsort torch.argsort Tensor.argsort

統計量

numpy torch torch.Tensor torch.Tensor (inplace)
中央値 numpy.median torch.median Tensor.median
平均 numpy.mean torch.mean Tensor.mean
分散 numpy.var torch.var Tensor.var
標準偏差 numpy.std torch.std Tensor.std
標準偏差と平均 torch.std_mean
分散と平均 torch.var_mean

型の変換

numpy torch torch.Tensor torch.Tensor (inplace)
最小の型を取得 numpy.promote_types torch.promote_types
共通の型を取得 numpy.result_type torch.result_type
キャスト可能かどうか numpy.can_cast torch.can_cast
キャスト numpy.ndarray.astype Tensor.type
キャスト numpy.ndarray.astype Tensor.type_as
スカラーに変換 numpy.asscalar Tensor.item
リストに変換 numpy.ndarray.tolist Tensor.tolist
連続した配列に変換 numpy.ascontiguousarray Tensor.contiguous

形状の操作

軸の操作

numpy torch torch.Tensor torch.Tensor (inplace)
軸変更 numpy.transpose Tensor.permute
軸変更 ndarray.T torch.t Tensor.t Tensor.t_
軸変更 numpy.rollaxis
軸変更 numpy.swapaxes torch.transpose Tensor.transpose Tensor.transpose_

結合、分割

numpy torch torch.Tensor torch.Tensor (inplace)
既存の軸で結合 numpy.concatenate torch.cat
新しい軸を追加して結合 numpy.stack torch.stack
分割 numpy.split torch.split Tensor.split
分割 Tensor.unfold

複素数

numpy torch torch.Tensor torch.Tensor (inplace)
実部 numpy.real torch.real
虚部 numpy.imag torch.imag
共役複素数 numpy.conj torch.conj Tensor.conj
複素数かどうか numpy.iscomplex torch.is_complex Tensor.is_complex
偏角 numpy.angle torch.angle Tensor.angle

浮動小数点数

numpy torch torch.Tensor torch.Tensor (inplace)
浮動小数点数 numpy.isfinite torch.isfinite
浮動小数点数 numpy.isinf torch.isinf
浮動小数点数 numpy.isnan torch.isnan

重複処理

numpy torch torch.Tensor torch.Tensor (inplace)
各要素の出現数 numpy.bincount torch.bincount Tensor.bincount
重複削除 numpy.unique torch.unique Tensor.unique

インデックス

numpy torch torch.Tensor torch.Tensor (inplace)
3項間演算子 numpy.where torch.where Tensor.where
非0の要素のインデックス numpy.nonzero torch.nonzero Tensor.nonzero
下三角成分のインデックス numpy.tril_indices torch.tril_indices
上三角成分のインデックス numpy.triu_indices torch.triu_indices
インデックス numpy.indices Tensor.indices

対角成分、上三角成分、下三角成分

ndarray の attribute

numpy torch torch.Tensor torch.Tensor (inplace)
要素数 numpy.ndarray.size torch.numel Tensor.numel
次元数 numpy.ndarray.ndim Tensor.dim
形状 numpy.ndarray.shape Tensor.size
ストライド numpy.ndarray.strides Tensor.stride
データ numpy.ndarray.data Tensor.data
numpy.ndarray.dtype Tensor.dtype

値の設定、取得

numpy torch torch.Tensor torch.Tensor (inplace)
値の設定 numpy.put Tensor.put_
値の取得 numpy.take torch.take Tensor.take
値の取得 numpy.select Tensor.select

乱数

コピー

numpy torch torch.Tensor torch.Tensor (inplace)
コピー numpy.copy Tensor.new_tensor Tensor.copy_
コピー Tensor.clone

オブジェクトからテンソルを作成

numpy torch torch.Tensor torch.Tensor (inplace)
オブジェクト → Tensor numpy.array torch.tensor
numpy → Tensor torch.from_numpy