pytest封神之路第五步 引數化進階

dongfanger發表於2020-09-23

用過unittest的朋友,肯定知道可以藉助DDT實現引數化。用過JMeter的朋友,肯定知道JMeter自帶了4種引數化方式(見參考資料)。pytest同樣支援引數化,而且很簡單很實用。

語法

在《pytest封神之路第三步 精通fixture》和《pytest封神之路第四步 內建和自定義marker》兩篇文章中,都提到了pytest引數化。那麼本文就趁著熱乎,趕緊聊一聊pytest的引數化是怎麼玩的。

@pytest.mark.parametrize

@pytest.mark.parametrize("test_input,expected", [("3+5", 8), ("2+4", 6), ("6*9", 42)])
def test_eval(test_input, expected):
    assert eval(test_input) == expected
  • 可以自定義變數,test_input對應的值是"3+5" "2+4" "6*9",expected對應的值是8 6 42,多個變數用tuple,多個tuple用list

  • 引數化的變數是引用而非複製,意味著如果值是list或dict,改變值會影響後續的test

  • 重疊產生笛卡爾積

    import pytest
    
    
    @pytest.mark.parametrize("x", [0, 1])
    @pytest.mark.parametrize("y", [2, 3])
    def test_foo(x, y):
        pass
    

@pytest.fixture()

@pytest.fixture(scope="module", params=["smtp.gmail.com", "mail.python.org"])
def smtp_connection(request):
    smtp_connection = smtplib.SMTP(request.param, 587, timeout=5)
  • 只能使用request.param來引用

  • 引數化生成的test帶有ID,可以使用-k來篩選執行。預設是根據函式名[引數名]來的,可以使用ids來定義

    // list
    @pytest.fixture(params=[0, 1], ids=["spam", "ham"])
    // function
    @pytest.fixture(params=[0, 1], ids=idfn)
    

    使用--collect-only 命令列引數可以看到生成的IDs。

引數新增marker

我們知道了引數化後會生成多個tests,如果有些test需要marker,可以用pytest.param來新增

marker方式

# content of test_expectation.py
import pytest


@pytest.mark.parametrize(
    "test_input,expected",
    [("3+5", 8), ("2+4", 6), pytest.param("6*9", 42, marks=pytest.mark.xfail)],
)
def test_eval(test_input, expected):
    assert eval(test_input) == expected

fixture方式

# content of test_fixture_marks.py
import pytest


@pytest.fixture(params=[0, 1, pytest.param(2, marks=pytest.mark.skip)])
def data_set(request):
    return request.param
def test_data(data_set):
    pass

pytest_generate_tests

用來自定義引數化方案。使用到了hook,hook的知識我會寫在《pytest hook》中,歡迎關注公眾號dongfanger獲取最新文章。

# content of conf.py


def pytest_generate_tests(metafunc):
    if "test_input" in metafunc.fixturenames:
        metafunc.parametrize("test_input", [0, 1])
# content of test.py


def test(test_input):
    assert test_input == 0
  • 定義在conftest.py檔案中
  • metafunc有5個屬性,fixturenames,module,config,function,cls
  • metafunc.parametrize() 用來實現引數化
  • 多個metafunc.parametrize() 的引數名不能重複,否則會報錯

引數化誤區

在講示例之前,先簡單分享我的菜雞行為。假設我們現在需要對50個介面測試,驗證某一角色的使用者訪問這些介面會返回403。我的做法是,把介面請求全部引數化了,test函式裡面只有斷言,虛擬碼大致如下

def api():
    params = []
    def func():
        return request()
    params.append(func)
    ...


@pytest.mark.parametrize('req', api())
def test():
    res = req()
    assert res.status_code == 403

這樣引數化以後,會產生50個tests,如果斷言失敗了,會單獨標記為failed,不影響其他test結果。咋一看還行,但是有個問題,在迴歸的時候,可能只需要驗證其中部分介面,就沒有辦法靈活的調整,必須全部跑一遍才行。這是一個相對錯誤的示範,至於正確的應該怎麼寫,相信每個人心中都有一個答案,能解決問題就是ok的。我想表達的是,引數化要適當,不要濫用,最好只對測試資料做引數化

實踐

本文的重點來了,引數化的語法比較簡單,實際應用是關鍵。這部分通過11個例子,來實踐一下。示例覆蓋的知識點有點多,建議留大段時間細看。

1.使用hook新增命令列引數--all,"param1"是引數名,帶--all引數時是range(5) == [0, 1, 2, 3, 4],生成5個tests。不帶引數時是range(2)。

# content of test_compute.py


def test_compute(param1):
    assert param1 < 4

# content of conftest.py


def pytest_addoption(parser):
    parser.addoption("--all", action="store_true", help="run all combinations")
def pytest_generate_tests(metafunc):
    if "param1" in metafunc.fixturenames:
        if metafunc.config.getoption("all"):
            end = 5
        else:
            end = 2
        metafunc.parametrize("param1", range(end))

2.testdata是測試資料,包括2組。test_timedistance_v0不帶ids。test_timedistance_v1帶list格式的ids。test_timedistance_v2的ids為函式。test_timedistance_v3使用pytest.param同時定義測試資料和id。

# content of test_time.py
from datetime import datetime, timedelta

import pytest

testdata = [
    (datetime(2001, 12, 12), datetime(2001, 12, 11), timedelta(1)),
    (datetime(2001, 12, 11), datetime(2001, 12, 12), timedelta(-1)),
]


@pytest.mark.parametrize("a,b,expected", testdata)
def test_timedistance_v0(a, b, expected):
    diff = a - b
    assert diff == expected


@pytest.mark.parametrize("a,b,expected", testdata, ids=["forward", "backward"])
def test_timedistance_v1(a, b, expected):
    diff = a - b
    assert diff == expected


def idfn(val):
    if isinstance(val, (datetime,)):
        # note this wouldn't show any hours/minutes/seconds
        return val.strftime("%Y%m%d")


@pytest.mark.parametrize("a,b,expected", testdata, ids=idfn)
def test_timedistance_v2(a, b, expected):
    diff = a - b
    assert diff == expected


@pytest.mark.parametrize(
    "a,b,expected",
    [
        pytest.param(
            datetime(2001, 12, 12), datetime(2001, 12, 11), timedelta(1), id="forward"
        ),
        pytest.param(
            datetime(2001, 12, 11), datetime(2001, 12, 12), timedelta(-1), id="backward"
        ),
    ],
)
def test_timedistance_v3(a, b, expected):
    diff = a - b
    assert diff == expected

3.相容unittest的testscenarios

# content of test_scenarios.py
def pytest_generate_tests(metafunc):
    idlist = []
    argvalues = []
    for scenario in metafunc.cls.scenarios:
        idlist.append(scenario[0])
        items = scenario[1].items()
        argnames = [x[0] for x in items]
        argvalues.append([x[1] for x in items])
    metafunc.parametrize(argnames, argvalues, ids=idlist, scope="class")


scenario1 = ("basic", {"attribute": "value"})
scenario2 = ("advanced", {"attribute": "value2"})


class TestSampleWithScenarios:
    scenarios = [scenario1, scenario2]

    def test_demo1(self, attribute):
        assert isinstance(attribute, str)

    def test_demo2(self, attribute):
        assert isinstance(attribute, str)

4.初始化資料庫連線

# content of test_backends.py
import pytest


def test_db_initialized(db):
    # a dummy test
    if db.__class__.__name__ == "DB2":
        pytest.fail("deliberately failing for demo purposes")

# content of conftest.py
import pytest


def pytest_generate_tests(metafunc):
    if "db" in metafunc.fixturenames:
        metafunc.parametrize("db", ["d1", "d2"], indirect=True)


class DB1:
    "one database object"


class DB2:
    "alternative database object"


@pytest.fixture
def db(request):
    if request.param == "d1":
        return DB1()
    elif request.param == "d2":
        return DB2()
    else:
        raise ValueError("invalid internal test config")

5.如果不加indirect=True,會生成2個test,fixt的值分別是"a"和"b"。如果加了indirect=True,會先執行fixture,fixt的值分別是"aaa"和"bbb"。indirect=True結合fixture可以在生成test前,對引數變數額外處理。

import pytest


@pytest.fixture
def fixt(request):
    return request.param * 3


@pytest.mark.parametrize("fixt", ["a", "b"], indirect=True)
def test_indirect(fixt):
    assert len(fixt) == 3

6.多個引數時,indirect賦值list可以指定某些變數應用fixture,沒有指定的保持原值。

# content of test_indirect_list.py
import pytest


@pytest.fixture(scope="function")
def x(request):
    return request.param * 3


@pytest.fixture(scope="function")
def y(request):
    return request.param * 2


@pytest.mark.parametrize("x, y", [("a", "b")], indirect=["x"])
def test_indirect(x, y):
    assert x == "aaa"
    assert y == "b"

7.相容unittest引數化

# content of ./test_parametrize.py
import pytest


def pytest_generate_tests(metafunc):
    # called once per each test function
    funcarglist = metafunc.cls.params[metafunc.function.__name__]
    argnames = sorted(funcarglist[0])
    metafunc.parametrize(
        argnames, [[funcargs[name] for name in argnames] for funcargs in funcarglist]
    )


class TestClass:
    # a map specifying multiple argument sets for a test method
    params = {
        "test_equals": [dict(a=1, b=2), dict(a=3, b=3)],
        "test_zerodivision": [dict(a=1, b=0)],
    }

    def test_equals(self, a, b):
        assert a == b

    def test_zerodivision(self, a, b):
        with pytest.raises(ZeroDivisionError):
            a / b

8.在不同python直譯器之間測試物件序列化。python1把物件pickle-dump到檔案。python2從檔案中pickle-load物件。

"""
module containing a parametrized tests testing cross-python
serialization via the pickle module.
"""
import shutil
import subprocess
import textwrap

import pytest

pythonlist = ["python3.5", "python3.6", "python3.7"]


@pytest.fixture(params=pythonlist)
def python1(request, tmpdir):
    picklefile = tmpdir.join("data.pickle")
    return Python(request.param, picklefile)


@pytest.fixture(params=pythonlist)
def python2(request, python1):
    return Python(request.param, python1.picklefile)


class Python:
    def __init__(self, version, picklefile):
        self.pythonpath = shutil.which(version)
        if not self.pythonpath:
            pytest.skip("{!r} not found".format(version))
        self.picklefile = picklefile

    def dumps(self, obj):
        dumpfile = self.picklefile.dirpath("dump.py")
        dumpfile.write(
            textwrap.dedent(
                r"""
                import pickle
                f = open({!r}, 'wb')
                s = pickle.dump({!r}, f, protocol=2)
                f.close()
                """.format(
                    str(self.picklefile), obj
                )
            )
        )
        subprocess.check_call((self.pythonpath, str(dumpfile)))

    def load_and_is_true(self, expression):
        loadfile = self.picklefile.dirpath("load.py")
        loadfile.write(
            textwrap.dedent(
                r"""
                import pickle
                f = open({!r}, 'rb')
                obj = pickle.load(f)
                f.close()
                res = eval({!r})
                if not res:
                raise SystemExit(1)
                """.format(
                    str(self.picklefile), expression
                )
            )
        )
        print(loadfile)
        subprocess.check_call((self.pythonpath, str(loadfile)))


@pytest.mark.parametrize("obj", [42, {}, {1: 3}])
def test_basic_objects(python1, python2, obj):
    python1.dumps(obj)
    python2.load_and_is_true("obj == {}".format(obj))

9.假設有個API,basemod是原始版本,optmod是優化版本,驗證二者結果一致。

# content of conftest.py
import pytest


@pytest.fixture(scope="session")
def basemod(request):
    return pytest.importorskip("base")


@pytest.fixture(scope="session", params=["opt1", "opt2"])
def optmod(request):
    return pytest.importorskip(request.param)

# content of base.py


def func1():
    return 1
# content of opt1.py


def func1():
    return 1.0001
# content of test_module.py
def test_func1(basemod, optmod):
    assert round(basemod.func1(), 3) == round(optmod.func1(), 3)

10.使用pytest.param新增marker和id。

# content of test_pytest_param_example.py
import pytest


@pytest.mark.parametrize(
    "test_input,expected",
    [
        ("3+5", 8),
        pytest.param("1+7", 8, marks=pytest.mark.basic),
        pytest.param("2+4", 6, marks=pytest.mark.basic, id="basic_2+4"),
        pytest.param(
            "6*9", 42, marks=[pytest.mark.basic, pytest.mark.xfail], id="basic_6*9"
        ),
    ],
)
def test_eval(test_input, expected):
    assert eval(test_input) == expected

11.使用pytest.raises讓部分test丟擲Error。

from contextlib import contextmanager

import pytest


// 3.7+ from contextlib import nullcontext as does_not_raise
@contextmanager
def does_not_raise():
    yield


@pytest.mark.parametrize(
    "example_input,expectation",
    [
        (3, does_not_raise()),
        (2, does_not_raise()),
        (1, does_not_raise()),
        (0, pytest.raises(ZeroDivisionError)),
    ],
)
def test_division(example_input, expectation):
    """Test how much I know division."""
    with expectation:
        assert (6 / example_input) is not None

簡要回顧

本文先講了引數化的語法,包括marker,fixture,hook方式,以及如何給引數新增marker,然後重點列舉了幾個實戰示例。引數化用好了能節省編碼,達到事半功倍的效果。

參考資料

docs-pytest-org-en-stable

JMeter4種引數化方式,請閱讀公眾號《三道題加油站 (2)》

相關文章