python のローカル変数のスコープと functools.partial

今日のネタは、ちょっとミスって pythonエンバグしてしまったので、自戒の意味を込めたメモです。

失敗コードは下記のようなものです(無関係の所はごっそり削ってます)。 引数 e や n の値に応じて、 go に関数オブジェクト (か None) を設定しています。

def in_stage(e=None, n=None):
	go = None
	if (e != None):
		go = lambda: goto_event_stage(e)
	elif (n != None):
		go = lambda: goto_normal_stage(n)

	...(実際の処理本体。途中で go() を使う)...

これ、何が悪いかというと、下のほうの処理本体部分で e や n を変更してしまうと、goを呼んだ時の挙動が変わってしまうんですね。

	# 関数の引数が n=3 だった場合
	go()   # goto_normal_state(3) が呼ばれる

	n = 100 # でも途中で n を変更してしまうと
	go()   # goto_normal_stage(100) が呼ばれる

この現象は lambda 式ではなく、def による名前持ちの関数であっても起こります。

原因と対処法

python のローカル変数は、関数の引数に定義されていたり、代入を行った場合には、そのスコープ内に作成されます。本件の例では、代入は lambda の中では行われていないので、新規の変数は作成されません。登場するすべてのローカル変数は、親の in_stage のものであり、lambda 内で参照している x は、親スコープの x そのものです。

この辺、若干 Python 固有なので、幾つか例を挙げてみます。

x = 3
def func():
	print x    #「3」が出力される。 この x は、親のx
func()
def func():
	print x    #「3」。実行時点では親のxは値を持っているのでエラーじゃない
x = 3
func()
def func():
	x = 5   # func 内で x を代入したので、
	        # func 内の x と親の x は別物になった
	print x    # 「5」
x = 3
func()
print x   # 「3」親の x は変更されていない

まあ、ここまでは良い。ここからがややこしい。

x = 3
def func():
	print x   # エラー: x は未初期化。なぜなら……
	x = 5     # ここで代入しているので func 内での x はすべて
	          # 親の x とは別物になったので。
func()
x = 3
def func():
	x += 1   # 同上。エラー: x は未初期化。 この文は x = x + 1 と同義だが、
	         # 「代入しているのでxは全部別物」で、「まだ設定してないxを読んでいる」のでエラー
	print x
func()

名前のバインドの扱いが他の言語と若干違うんですよね。代入されたタイミングで新しくバインドされるのではなく、そのスコープ全体に遡って適用されるというか。同じスコープ内で、同じ名前が指しているものは同じものという保証がある、というか。

じゃぁ、クロージャ的なものはどうするのよ、って話ですが、「そういった『個々に固有の値を保持するオブジェクトを作りたい』のであれば、ちゃんとオブジェクトを作って、オブジェクトに持たせなさい」というのが、python の流儀なのではないかと。

実装例はこちらの方が書かれています。→ 最もタメになる「初心者用言語」は Python! - 西尾泰和のはてなダイアリー

例が2つ挙げられていますが、いずれの例でも、覚える値はオブジェクトのアトリビュートとして保持しています。

私はどう書くべきだったか

……と、ここで終わってはいけません。
クロージャを書くことが目的ではないのですから。より問題に合った、わかりやすい記述を求めるべきでしょう。

やりたかったのは「go という変数に、goto_normal_stage(3)とかを呼んでくれる関数オブジェクトを設定したい」という物ですが、もし、goto_normal_stage_3() みたいな関数が定義済みであれば

if n = 3: go = goto_normal_stage_3

とでも書けば良かったわけです。
つまり goto_normal_stage_3 のような、「(一部の)引数が設定済みの関数オブジェクト」を動的に作る方法が欲しかった訳です。これは関数型言語で言う所の「部分適用」です。

そのための道具は functools に partial() として既に用意されています。

from functools import partial
def in_stage(e=None, n=None):
	go = None
	if (e != None):
		go = partial(goto_event_stage, e)
	elif (n != None):
		go = partial(goto_normal_stage, n)

これですっきり。