chainer 2.0勉強記録

チュートリアルを一通りやった上で,追加でハマったことのメモ

Datasetの作成方法

以下のように,DatasetMixinを使って作る.データの生成(ファイル読み込みなど)をget_exampleまで遅延させるのもOK

class Dataset(chainer.dataset.DatasetMixin):
    def __init__(self):
        xs = np.array(np.random.uniform(-math.pi, math.pi, (10000, 1)), dtype=np.float32)
        f = lambda t: np.array([math.sin(t)], dtype=np.float32)
        ys = np.array([f(e) for e in xs])
        self.input = xs
        self.output = ys
    def __len__(self):
        return len(self.output)
    def get_example(self, i):
        return self.input[i], self.output[i]

Trainerに与えるModelの実装方法

__call__の引数

第一引数が入力,第二引数が出力(actual)となり,loss関数を実装する.

__call__での引数の扱い

上記Datasetの場合,引数がただの変数なので,Variable(...)でラップする必要がある.

その他

計算途中を変数に置かないと上手く動かないケースがある(? 未検証)

例えば,

h1 = l1(x)
h2 = l2(h1)
return h2

は上手く動くが,

return l2(l2(x))

は上手く動かないような気がする.