Theano scan
This post summarizes what I learned from the official Theano scan tutorial.
The scan function provides loops (iteration) in Theano.
Its signature is:
theano.scan(fn, sequences=None, outputs_info=None, non_sequences=None, n_steps=None, truncate_gradient=-1, go_backwards=False, mode=None, name=None, profile=False, allow_gc=None, strict=False)
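Here fn is the step function applied at each iteration; it can be a lambda expression or a regular function. sequences are the sequences to iterate over. outputs_info provides the initial values of the results that each step receives from the previous iteration. non_sequences are arguments passed unchanged to every step without being iterated over. Setting strict=True forbids fn from accessing previously defined shared variables that are not listed in non_sequences; passing every input explicitly in this way helps scan optimize the code. Setting allow_gc=False disables garbage collection inside scan, which can also speed it up, at the cost of higher memory usage.
The arguments of fn must come in a fixed order: sequences first, then prior results (from outputs_info), then non_sequences.
If the sequences have different lengths, the number of iterations equals the length of the shortest one.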
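To make the roles of fn, sequences, outputs_info and non_sequences concrete, here is a pure-Python model of the loop that scan performs. It is only a sketch under simplifying assumptions: `toy_scan` is a hypothetical helper, not Theano's implementation, and it supports a single output whose previous value is fed back (tap -1), ignoring other taps, updates, gradients and graph compilation. It does show the argument order (sequences, prior result, non_sequences) and the rule that the shortest sequence determines the number of steps.

```python
def toy_scan(fn, sequences=(), outputs_info=None, non_sequences=(), n_steps=None):
    """Pure-Python model of the loop behind theano.scan (hypothetical helper).

    Supports one output whose previous value is fed back (tap -1 only);
    the real scan builds a symbolic graph and does much more.
    """
    sequences = list(sequences)
    if n_steps is None:
        if not sequences:
            raise ValueError("give either sequences or n_steps")
        # With sequences of different lengths, the shortest one wins.
        n_steps = min(len(s) for s in sequences)
    results = []
    prior = outputs_info
    for t in range(n_steps):
        # Argument order: sequences, prior result, non_sequences.
        args = [s[t] for s in sequences]
        if outputs_info is not None:
            args.append(prior)
        args.extend(non_sequences)
        prior = fn(*args)
        results.append(prior)
    return results


a = [1.0, 2.0, 3.0, 4.0]
b = [10.0, 20.0, 30.0]  # shorter sequence: only 3 steps are run
out = toy_scan(lambda a_t, b_t, acc, k: acc + k * (a_t + b_t),
               sequences=[a, b], outputs_info=0.0, non_sequences=[2.0])
print(out)  # [22.0, 66.0, 132.0]
```

The lambda's parameters follow the same order scan uses: one element per sequence (a_t, b_t), then the prior result (acc), then the non-sequence (k).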
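The initial value given in outputs_info must have exactly the same type as the value returned by each iteration; scan does not accept it even when an implicit conversion would be possible. This can be ensured as follows:
import numpy as np
import theano.tensor as T
# outputs_info = T.as_tensor_variable(0)  # a Python int gives an integer tensor, not seq's dtype
outputs_info = T.as_tensor_variable(np.asarray(0, seq.dtype))  # seq has the same dtype as the iteration result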
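The type rule can be illustrated without Theano. The toy loop below is only an analogy: `strict_scan` is a hypothetical helper that compares plain Python types, whereas Theano compares tensor types. Like scan, it rejects a run where the initial value is an int but each step produces a float, instead of silently converting.

```python
def strict_scan(fn, sequence, initial):
    """Toy accumulator loop (hypothetical helper) that, like theano.scan,
    refuses a step result whose type differs from the carried value."""
    results, prior = [], initial
    for x_t in sequence:
        new = fn(x_t, prior)
        if type(new) is not type(prior):
            raise TypeError(f"step returned {type(new).__name__}, "
                            f"but the previous value is {type(prior).__name__}")
        results.append(new)
        prior = new
    return results


seq = [0.5, 1.5, 2.0]
step = lambda x_t, acc: acc + x_t

print(strict_scan(step, seq, 0.0))  # [0.5, 2.0, 4.0]
try:
    strict_scan(step, seq, 0)  # int initial value, float step results
except TypeError as err:
    print("rejected:", err)  # rejected: step returned float, but the previous value is int
```

Starting from 0.0 instead of 0 plays the same role as np.asarray(0, seq.dtype) in the Theano snippet: the initial value already has the type the steps will produce.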
When scan uses taps, the naming convention is that a_tm2 stands for a(t-2) and b_tp3 stands for b(t+3) (tm = "t minus", tp = "t plus").
Note the order in which the parameters are given and in which the result is returned. Try to respect chronological order among the taps (time slices of sequences or outputs) used. For scan, it is crucial only that the variables representing the different time taps be in the same order as the one in which these taps are given. Not only the taps but also the variables must respect an order, since this is how scan figures out what should be represented by what.
An example:
import theano
import theano.tensor as T

def oneStep(u_tm4, u_t, x_tm3, x_tm1, y_tm1, W, W_in_1, W_in_2, W_feedback, W_out):
    x_t = T.tanh(theano.dot(x_tm1, W) +
                 theano.dot(u_t, W_in_1) +
                 theano.dot(u_tm4, W_in_2) +
                 theano.dot(y_tm1, W_feedback))
    y_t = theano.dot(x_tm3, W_out)
    return [x_t, y_t]

W = T.matrix()
W_in_1 = T.matrix()
W_in_2 = T.matrix()
W_feedback = T.matrix()
W_out = T.matrix()

u = T.matrix()   # it is a sequence of vectors
x0 = T.matrix()  # initial state of x has to be a matrix, since
                 # it has to cover x[-3]
y0 = T.vector()  # y0 is just a vector since scan has only to provide
                 # y[-1]

([x_vals, y_vals], updates) = theano.scan(fn=oneStep,
                                          sequences=dict(input=u, taps=[-4, -0]),
                                          outputs_info=[dict(initial=x0, taps=[-3, -1]), y0],
                                          non_sequences=[W, W_in_1, W_in_2, W_feedback, W_out],
                                          strict=True)
# for the second output y, scan uses tap -1 by default
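The mapping from taps to step-function arguments in the example above can be modeled in plain Python. This is a sketch under stated assumptions: `scan_with_taps` and `step` are hypothetical, the model handles one sequence and one recurrent output (the second output y and the weight matrices are left out), scalars stand in for the vectors, and the initial state is assumed to hold exactly the required history, oldest first. In the model, the step function receives the sequence taps first and then the output taps, each in the order they were listed; the initial state must cover the oldest output tap (three entries for tap -3); and sequence taps [-4, 0] make the loop four steps shorter than the sequence.

```python
def scan_with_taps(fn, seq, seq_taps, init, out_taps, non_sequences=()):
    """Pure-Python model (hypothetical helper) of one sequence read through
    seq_taps and one recurrent output read through negative out_taps.

    init holds the output's history oldest first, e.g. [x[-3], x[-2], x[-1]].
    """
    lo, hi = min(seq_taps), max(seq_taps)
    # Every step needs seq[t + lo] .. seq[t + hi], so the loop is shorter
    # than the sequence by the span of the taps.
    n_steps = len(seq) - (hi - lo)
    history = list(init)
    if len(history) < -min(out_taps):
        raise ValueError("initial state does not cover the oldest output tap")
    outputs = []
    for t in range(n_steps):
        now = t - lo  # index in seq that plays the role of time t
        seq_args = [seq[now + k] for k in seq_taps]
        out_args = [history[k] for k in out_taps]  # history[-1] is out[t-1]
        new = fn(*seq_args, *out_args, *non_sequences)
        history.append(new)
        outputs.append(new)
    return outputs  # the initial state is not part of the result


# Same tap layout as the example: u taps [-4, 0], x taps [-3, -1].
def step(u_tm4, u_t, x_tm3, x_tm1):
    return u_tm4 + u_t + x_tm3 + x_tm1


u = [1, 2, 3, 4, 5, 6]  # 6 entries, taps span 4, so 2 steps
x0 = [100, 200, 300]     # x[-3], x[-2], x[-1]
print(scan_with_taps(step, u, [-4, 0], x0, [-3, -1]))  # [406, 614]
```

At the first step, u_tm4 is u[0] and u_t is u[4], while x_tm3 and x_tm1 come from x0; from then on the newly computed values are appended to the history and read back through the same taps.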
Ways to optimize scan code include:
- Minimizing Scan usage
- Explicitly passing inputs of the inner function to scan
- Deactivating garbage collecting in Scan
- Graph optimizations
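The first tip, minimizing scan usage, amounts to doing as much of the computation as possible outside the loop: anything that does not depend on the previous step can be computed for all time steps at once. The pure-Python sketch below uses hypothetical helpers (`rnn_in_loop`, `rnn_hoisted`) and scalar weights instead of matrices, and checks that hoisting the input projection out of the loop leaves the result unchanged. In Theano the hoisted part would become a single matrix product over the whole input sequence, computed outside scan, which can reduce scan's per-step overhead at the cost of some extra memory.

```python
import math


def rnn_in_loop(us, w_in, w_rec, h0):
    """All work inside the loop, like a scan step that recomputes the input term."""
    hs, h = [], h0
    for u_t in us:
        h = math.tanh(w_rec * h + w_in * u_t)
        hs.append(h)
    return hs


def rnn_hoisted(us, w_in, w_rec, h0):
    """Input projection computed for every step up front; only the
    recurrence, which needs h[t-1], stays inside the loop."""
    projected = [w_in * u_t for u_t in us]
    hs, h = [], h0
    for p_t in projected:
        h = math.tanh(w_rec * h + p_t)
        hs.append(h)
    return hs


us = [0.5, -1.0, 2.0]
print(rnn_in_loop(us, 0.8, 0.5, 0.0) == rnn_hoisted(us, 0.8, 0.5, 0.0))  # True
```

Both versions perform the same floating-point operations in the same order, so the outputs match exactly; only where the input term is computed changes.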