Improve the pyproject.toml configuration, add first test, and create more type hints

This commit is contained in:
Zoé Cassiopée Gauthier 2024-04-17 18:06:46 -04:00
parent e1cf1cb194
commit 207178f822
9 changed files with 256 additions and 115 deletions

1
.gitignore vendored
View File

@ -1,4 +1,5 @@
__pycache__/ __pycache__/
.coverage
dist/ dist/
src/pogo_scaled_estimators/_version.py src/pogo_scaled_estimators/_version.py
*.sqlite *.sqlite

View File

@ -27,7 +27,7 @@ classifiers = [
] ]
[project.urls] [project.urls]
"Source" = "https://git.blorp.dev/zo/pogo-scaled-estimators" Source = "https://git.blorp.dev/zo/pogo-scaled-estimators"
[project.scripts] [project.scripts]
ase-cli = "pogo_scaled_estimators:main_cli" ase-cli = "pogo_scaled_estimators:main_cli"
@ -38,18 +38,46 @@ source = "vcs"
[tool.hatch.build.hooks.vcs] [tool.hatch.build.hooks.vcs]
version-file = "src/pogo_scaled_estimators/_version.py" version-file = "src/pogo_scaled_estimators/_version.py"
[tool.hatch.envs.lint] [tool.hatch.envs.default]
detached = true
dependencies = [ dependencies = [
"coverage[toml]",
"pyright", "pyright",
"pytest",
"requests-mock",
"ruff", "ruff",
] ]
[tool.hatch.envs.lint.scripts] [tool.hatch.envs.default.scripts]
all = ["style", "typing"] test = "pytest {args:tests}"
format = ["ruff format --fix {args:.}"] test-cov = "coverage run -m pytest {args:tests}"
style = ["ruff check {args:.}"] cov-report = [
typing = ["pyright"] "- coverage combine",
"coverage report",
]
cov = [
"test-cov",
"cov-report",
]
format-check = "ruff format --check --diff {args:.}"
format-fix = "ruff format {args:.}"
lint-check = "ruff check {args:.}"
lint-fix = "ruff check --fix {args:.}"
types-check = "pyright"
[tool.coverage.run]
branch = true
source_pkgs = ["pogo_scaled_estimators", "tests"]
[tool.coverage.paths]
pogo_scaled_estimators = ["src/pogo_scaled_estimators", "*/pogo_scaled_estimators/src/pogo_scaled_estimators"]
tests = ["tests", "*/pogo_scaled_estimators/tests"]
[tool.coverage.report]
exclude_lines = [
"no cov",
"if __name__ == .__main__.:",
"if TYPE_CHECKING:",
]
[tool.ruff] [tool.ruff]
target-version = "py312" target-version = "py312"
@ -112,46 +140,18 @@ ignore = [
] ]
[tool.pyright] [tool.pyright]
include = ["src/glimmer", "tests"] include = ["src/pogo_scaled_estimators", "tests"]
exclude = ["**/__pycache__"]
reportMissingImports = true reportMissingImports = true
reportMissingTypeStubs = false reportMissingTypeStubs = false
pythonVersion = "3.12" pythonVersion = "3.12"
pythonPlatform = "Linux" pythonPlatform = "Linux"
typeCheckingMode = "standard" typeCheckingMode = "strict"
strictListInference = true reportMissingParameterType = "none"
strictDictionaryInference = true reportUnknownArgumentType = "none"
strictSetInference = true reportUnknownLambdaType = "none"
reportAssertAlwaysTrue = "error" reportUnknownMemberType = "none"
reportInvalidStringEscapeSequence = "error" reportUnknownParameterType = "none"
reportSelfClsParameterName = "error" reportUnknownVariableType = "none"
reportConstantRedefinition = "error" reportUnusedFunction = "none"
reportDeprecated = "error"
reportDuplicateImport = "error"
reportIncompatibleMethodOverride = "error"
reportIncompatibleVariableOverride = "error"
reportInconsistentConstructor = "error"
reportMatchNotExhaustive = "warning"
reportOverlappingOverload = "error"
reportMissingSuperCall = "error"
reportPrivateUsage = "warning"
reportTypeCommentUsage = "error"
reportUnnecessaryCast = "error"
reportUnnecessaryComparison = "error"
reportUnnecessaryContains = "error"
reportUnnecessaryIsInstance = "error"
reportUnusedClass = "warning"
reportUnusedImport = "warning"
reportUnusedFunction = "warning"
reportUnusedVariable = "warning"
reportUntypedBaseClass = "error"
reportUntypedClassDecorator = "error"
reportUntypedFunctionDecorator = "error"
reportUntypedNamedTuple = "error"
reportCallInDefaultInitializer = "error"
reportImplicitOverride = "error"
reportPropertyTypeMismatch = "warning"
reportShadowedImports = "warning"
reportUninitializedInstanceVariable = "warning"
reportUnnecessaryTypeIgnoreComment = "warning"
reportUnusedCallResult = "warning"

View File

@ -6,11 +6,12 @@
import math import math
from enum import Flag, auto from enum import Flag, auto
from typing import final from typing import TypeGuard, final, get_args
from rich.progress import Progress from rich.progress import Progress
from pogo_scaled_estimators.pokebattler_proxy import MovesetResult, PokebattlerProxy, Raid from pogo_scaled_estimators.pokebattler_proxy import MovesetResult, PokebattlerProxy, Raid
from pogo_scaled_estimators.typing import PokemonType
from pogo_scaled_estimators.utilities import format_move_name, format_pokemon_name from pogo_scaled_estimators.utilities import format_move_name, format_pokemon_name
@ -22,17 +23,25 @@ class Filter(Flag):
DISALLOW_LEGENDARY_POKEMON = auto() DISALLOW_LEGENDARY_POKEMON = auto()
class NotAPokemonTypeError(TypeError):
def __init__(self, pokemon_types: list[str]):
super().__init__(f"{pokemon_types} contains invalid Pokebattler type name")
@final @final
class Calculator: class Calculator:
def __init__(self, attacker_types: list[str], filters: Filter = Filter.NO_FILTER) -> None: def __init__(self, attacker_types: list[str], filters: Filter = Filter.NO_FILTER) -> None:
self.attacker_types = attacker_types if self._is_list_of_types(attacker_types):
self.attacker_types: list[PokemonType] = attacker_types
else:
raise NotAPokemonTypeError(attacker_types)
self.filters = filters self.filters = filters
self._pokebattler_proxy = PokebattlerProxy() self._pokebattler_proxy = PokebattlerProxy()
self._progress = Progress() self._progress = Progress()
def calculate(self, level: int = 40, party: int = 1) -> None: def calculate(self, level: int = 40, party: int = 1) -> None:
raid_bosses = self._pokebattler_proxy.raid_bosses(self.attacker_types) raid_bosses = self._pokebattler_proxy.raid_bosses(self.attacker_types)
attackers = { attackers: dict[str, dict[str, list[MovesetResult]]] = {
attacker: {"RAID_LEVEL_3": [], "RAID_LEVEL_5": [], "RAID_LEVEL_MEGA": []} attacker: {"RAID_LEVEL_3": [], "RAID_LEVEL_5": [], "RAID_LEVEL_MEGA": []}
for attacker in self._pokebattler_proxy.with_charged_moves(self.attacker_types) for attacker in self._pokebattler_proxy.with_charged_moves(self.attacker_types)
if self._allowed_attacker(attacker) if self._allowed_attacker(attacker)
@ -66,6 +75,8 @@ class Calculator:
for attacker, movesets in results.items() for attacker, movesets in results.items()
if movesets if movesets
} }
if not best_movesets:
continue
best_estimator = min(best_movesets.values(), key=lambda moveset: moveset.estimator).estimator best_estimator = min(best_movesets.values(), key=lambda moveset: moveset.estimator).estimator
for attacker, moveset in best_movesets.items(): for attacker, moveset in best_movesets.items():
attackers[attacker][simplified_raid_tier].append(moveset.scale(best_estimator)) attackers[attacker][simplified_raid_tier].append(moveset.scale(best_estimator))
@ -86,6 +97,9 @@ class Calculator:
f"[bold]{format_pokemon_name(attacker, attacker_type)}[/bold] ({format_move_name(fast_move, fast_move_type)}/{format_move_name(charged_move, charged_move_type)}): {ase:.2f}" f"[bold]{format_pokemon_name(attacker, attacker_type)}[/bold] ({format_move_name(fast_move, fast_move_type)}/{format_move_name(charged_move, charged_move_type)}): {ase:.2f}"
) )
def _is_list_of_types(self, attacker_types: list[str]) -> TypeGuard[list[PokemonType]]:
return all(attacker_type in get_args(PokemonType) for attacker_type in attacker_types)
def _allowed_attacker(self, pokemon_id: str) -> bool: def _allowed_attacker(self, pokemon_id: str) -> bool:
if Filter.DISALLOW_MEGA_POKEMON in self.filters and ("_MEGA" in pokemon_id or "_PRIMAL" in pokemon_id): if Filter.DISALLOW_MEGA_POKEMON in self.filters and ("_MEGA" in pokemon_id or "_PRIMAL" in pokemon_id):
return False return False
@ -121,17 +135,19 @@ class Calculator:
return ("MOVE_NONE", "MOVE_NONE") return ("MOVE_NONE", "MOVE_NONE")
return min(movesets, key=lambda moveset: self._ase(moveset_results_by_tier, only=moveset)) return min(movesets, key=lambda moveset: self._ase(moveset_results_by_tier, only=moveset))
def _ase(self, moveset_results_by_tier: dict[str, list[MovesetResult]], only=None) -> float: def _ase(
try: self, moveset_results_by_tier: dict[str, list[MovesetResult]], only: tuple[str, str] | None = None
) -> float:
return ( return (
0.15 * self._average_estimator(moveset_results_by_tier["RAID_LEVEL_3"], only) 0.15 * self._average_estimator(moveset_results_by_tier["RAID_LEVEL_3"], only)
+ 0.50 * self._average_estimator(moveset_results_by_tier["RAID_LEVEL_5"], only) + 0.50 * self._average_estimator(moveset_results_by_tier["RAID_LEVEL_5"], only)
+ 0.35 * self._average_estimator(moveset_results_by_tier["RAID_LEVEL_MEGA"], only) + 0.35 * self._average_estimator(moveset_results_by_tier["RAID_LEVEL_MEGA"], only)
) )
except ZeroDivisionError:
return float("inf")
def _average_estimator(self, moveset_results: list[MovesetResult], only: tuple[str, str] | None = None) -> float: def _average_estimator(self, moveset_results: list[MovesetResult], only: tuple[str, str] | None = None) -> float:
if only: if only:
moveset_results = [m for m in moveset_results if m.fast_move == only[0] and m.charged_move == only[1]] moveset_results = [m for m in moveset_results if m.fast_move == only[0] and m.charged_move == only[1]]
try:
return sum(moveset.estimator for moveset in moveset_results) / len(moveset_results) return sum(moveset.estimator for moveset in moveset_results) / len(moveset_results)
except ZeroDivisionError:
return float("inf")

View File

@ -5,31 +5,47 @@
# https://opensource.org/licenses/MIT. # https://opensource.org/licenses/MIT.
import argparse import argparse
import logging
import operator import operator
import sys import sys
from functools import reduce from functools import reduce
from rich.logging import RichHandler
from pogo_scaled_estimators._version import version from pogo_scaled_estimators._version import version
from pogo_scaled_estimators.calculator import Calculator, Filter from pogo_scaled_estimators.calculator import Calculator, Filter
def main_cli(): def main_cli():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
_ = parser.add_argument("type", nargs="+", help="an attacker type") parser.add_argument("type", nargs="+", help="an attacker type")
_ = parser.add_argument("--level", type=int, default=40) parser.add_argument("--level", type=int, default=40)
_ = parser.add_argument("--party", type=int, default=1) parser.add_argument("--party", type=int, default=1)
_ = parser.add_argument("--no-legacy", dest="filters", action="append_const", const=Filter.DISALLOW_LEGACY_MOVES) parser.add_argument("--no-legacy", dest="filters", action="append_const", const=Filter.DISALLOW_LEGACY_MOVES)
_ = parser.add_argument("--no-mega", dest="filters", action="append_const", const=Filter.DISALLOW_MEGA_POKEMON) parser.add_argument("--no-mega", dest="filters", action="append_const", const=Filter.DISALLOW_MEGA_POKEMON)
_ = parser.add_argument("--no-shadow", dest="filters", action="append_const", const=Filter.DISALLOW_SHADOW_POKEMON) parser.add_argument("--no-shadow", dest="filters", action="append_const", const=Filter.DISALLOW_SHADOW_POKEMON)
_ = parser.add_argument( parser.add_argument(
"--no-legendary", dest="filters", action="append_const", const=Filter.DISALLOW_LEGENDARY_POKEMON "--no-legendary", dest="filters", action="append_const", const=Filter.DISALLOW_LEGENDARY_POKEMON
) )
_ = parser.add_argument("--version", action="version", version=version) parser.add_argument(
"-v", "--verbose", action="store_const", dest="log_level", const=logging.DEBUG, default=logging.WARNING
)
parser.add_argument("--version", action="version", version=version)
args = parser.parse_args() args = parser.parse_args()
logging.basicConfig(level=args.log_level, handlers=[RichHandler()])
log = logging.getLogger()
requests_log = logging.getLogger("urllib3")
requests_log.setLevel(args.log_level)
requests_log.propagate = True
filters = reduce(operator.or_, args.filters or [], Filter.NO_FILTER) filters = reduce(operator.or_, args.filters or [], Filter.NO_FILTER)
try:
calculator = Calculator(args.type, filters) calculator = Calculator(args.type, filters)
calculator.calculate(level=args.level, party=args.party) calculator.calculate(level=args.level, party=args.party)
except Exception:
log.exception("Could not calculate ASE.")
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -7,10 +7,12 @@
import urllib.parse import urllib.parse
from dataclasses import dataclass from dataclasses import dataclass
from functools import cached_property from functools import cached_property
from typing import cast, final from typing import NotRequired, TypedDict, cast, final
import requests_cache import requests_cache
from pogo_scaled_estimators.typing import PokemonType
BASE_URL = "https://fight.pokebattler.com" BASE_URL = "https://fight.pokebattler.com"
WEAKNESS = 1.6 WEAKNESS = 1.6
DOUBLE_WEAKNESS = WEAKNESS * WEAKNESS DOUBLE_WEAKNESS = WEAKNESS * WEAKNESS
@ -27,7 +29,7 @@ class Raid:
@dataclass @dataclass
class Move: class Move:
move_id: str move_id: str
typing: str typing: PokemonType
@dataclass @dataclass
@ -36,37 +38,58 @@ class MovesetResult:
charged_move: str charged_move: str
estimator: float estimator: float
def scale(self, factor: float): def scale(self, factor: float) -> "MovesetResult":
return MovesetResult(self.fast_move, self.charged_move, self.estimator / factor) return MovesetResult(self.fast_move, self.charged_move, self.estimator / factor)
class PokebattlerMove(TypedDict):
moveId: str
type: PokemonType
class PokebattlerPokemon(TypedDict):
pokemonId: str
type: PokemonType
type2: NotRequired[PokemonType]
quickMoves: list[str]
cinematicMoves: list[str]
movesets: list[dict[str, str]]
class PokebattlerRaid(TypedDict):
pokemonId: str
class PokebattlerRaidTierInfo(TypedDict):
guessTier: str
class PokebattlerRaidTier(TypedDict):
tier: str
info: PokebattlerRaidTierInfo
raids: list[PokebattlerRaid]
@final @final
class PokebattlerProxy: class PokebattlerProxy:
def __init__(self): def __init__(self, log_level="INFO"):
self._cached_session = requests_cache.CachedSession("pokebatter_cache", cache_control=True) self._cached_session = requests_cache.CachedSession("pokebatter_cache", cache_control=True, use_cache_dir=True)
self._pokemon: dict | None = None
self._raids: dict | None = None
self._resists: dict | None = None
@property
def cached_session(self):
return self._cached_session
@cached_property @cached_property
def moves(self) -> dict: def moves(self) -> list[PokebattlerMove]:
return self.cached_session.get(f"{BASE_URL}/moves").json()["move"] return self._cached_session.get(f"{BASE_URL}/moves").json()["move"]
@cached_property @cached_property
def pokemon(self) -> dict: def pokemon(self) -> list[PokebattlerPokemon]:
return self.cached_session.get(f"{BASE_URL}/pokemon").json()["pokemon"] return self._cached_session.get(f"{BASE_URL}/pokemon").json()["pokemon"]
@cached_property @cached_property
def raids(self) -> dict: def raids(self) -> list[PokebattlerRaidTier]:
return self.cached_session.get(f"{BASE_URL}/raids").json() return self._cached_session.get(f"{BASE_URL}/raids").json()["tiers"]
@cached_property @cached_property
def resists(self) -> dict: def resists(self) -> dict[str, list[float]]:
return self.cached_session.get(f"{BASE_URL}/resists").json() return self._cached_session.get(f"{BASE_URL}/resists").json()
def simulate(self, raid: Raid) -> dict[str, list[MovesetResult]]: def simulate(self, raid: Raid) -> dict[str, list[MovesetResult]]:
query_string = { query_string = {
@ -85,7 +108,8 @@ class PokebattlerProxy:
url = f"{BASE_URL}/raids/defenders/{raid.defender}/levels/{raid.tier}/attackers/levels/{raid.level}/strategies/CINEMATIC_ATTACK_WHEN_POSSIBLE/DEFENSE_RANDOM_MC?{urllib.parse.urlencode(query_string, doseq=True)}" url = f"{BASE_URL}/raids/defenders/{raid.defender}/levels/{raid.tier}/attackers/levels/{raid.level}/strategies/CINEMATIC_ATTACK_WHEN_POSSIBLE/DEFENSE_RANDOM_MC?{urllib.parse.urlencode(query_string, doseq=True)}"
response = self._cached_session.get(url) response = self._cached_session.get(url)
results: dict[str, list[MovesetResult]] = {} results: dict[str, list[MovesetResult]] = {}
for attacker in response.json()["attackers"][0]["randomMove"]["defenders"]: response_json = response.json()
for attacker in response_json["attackers"][0]["randomMove"]["defenders"]:
results[attacker["pokemonId"]] = [ results[attacker["pokemonId"]] = [
MovesetResult( MovesetResult(
attacker_moves["move1"], attacker_moves["move2"], cast(float, attacker_moves["result"]["estimator"]) attacker_moves["move1"], attacker_moves["move2"], cast(float, attacker_moves["result"]["estimator"])
@ -94,9 +118,9 @@ class PokebattlerProxy:
] ]
return results return results
def raid_bosses(self, attacker_types: list[str]) -> dict: def raid_bosses(self, attacker_types: list[PokemonType]) -> dict[str, list[str]]:
raid_tiers = [] raid_tiers: list[str] = []
raid_bosses = {} raid_bosses: dict[str, list[str]] = {}
for raid_level in ["3", "5", "MEGA", "MEGA_5", "ULTRA_BEAST"]: for raid_level in ["3", "5", "MEGA", "MEGA_5", "ULTRA_BEAST"]:
tier = f"RAID_LEVEL_{raid_level}" tier = f"RAID_LEVEL_{raid_level}"
@ -109,21 +133,12 @@ class PokebattlerProxy:
) )
raid_bosses[tier] = [] raid_bosses[tier] = []
for tier in filter(lambda tier: tier["tier"] in raid_tiers, self.raids["tiers"]): for tier in filter(lambda tier: tier["tier"] in raid_tiers, self.raids):
for boss in (raid["pokemon"] for raid in tier["raids"]): for boss in (raid["pokemonId"] for raid in tier["raids"]):
if boss.endswith("_FORM"): if boss.endswith("_FORM"):
continue continue
boss_pokemon = next(filter(lambda mon: mon["pokemonId"] == boss, self.pokemon)) boss_pokemon: PokebattlerPokemon = next(filter(lambda mon: mon["pokemonId"] == boss, self.pokemon))
if ("candyToEvolve" in boss_pokemon or boss in ["SEADRA", "SEALEO"]) and boss not in [ boss_types: tuple[PokemonType, PokemonType] = (
"KELDEO",
"LUMINEON",
"MANAPHY",
"PHIONE",
"STUNFISK",
"TERRAKION",
]:
continue
boss_types = (
boss_pokemon["type"], boss_pokemon["type"],
boss_pokemon.get("type2", "POKEMON_TYPE_NONE"), boss_pokemon.get("type2", "POKEMON_TYPE_NONE"),
) )
@ -132,7 +147,7 @@ class PokebattlerProxy:
return raid_bosses return raid_bosses
def _is_weak(self, attacker_type: str, defender_types: tuple[str, str]) -> bool: def _is_weak(self, attacker_type: PokemonType, defender_types: tuple[PokemonType, PokemonType]) -> bool:
pokemon_types = list(self.resists.keys()) pokemon_types = list(self.resists.keys())
defender_type_indices = ( defender_type_indices = (
pokemon_types.index(defender_types[0]), pokemon_types.index(defender_types[0]),
@ -156,8 +171,8 @@ class PokebattlerProxy:
return False return False
def with_charged_moves(self, attacker_types: list[str]) -> list[str]: def with_charged_moves(self, attacker_types: list[PokemonType]) -> list[str]:
charged_moves = [ charged_moves: list[str] = [
move["moveId"] move["moveId"]
for move in self.moves for move in self.moves
if "moveId" in move and "type" in move and move["type"] in attacker_types if "moveId" in move and "type" in move and move["type"] in attacker_types
@ -168,10 +183,10 @@ class PokebattlerProxy:
if any(moveset["cinematicMove"] in charged_moves for moveset in mon["movesets"]) if any(moveset["cinematicMove"] in charged_moves for moveset in mon["movesets"])
] ]
def find_pokemon(self, name: str) -> dict: def find_pokemon(self, name: str) -> PokebattlerPokemon:
return next(filter(lambda mon: mon["pokemonId"] == name, self.pokemon)) return next(filter(lambda mon: mon["pokemonId"] == name, self.pokemon))
def pokemon_type(self, name: str) -> str: def pokemon_type(self, name: str) -> PokemonType:
return self.find_pokemon(name)["type"] return self.find_pokemon(name)["type"]
def find_move(self, move_id: str) -> Move: def find_move(self, move_id: str) -> Move:

View File

@ -0,0 +1,29 @@
# Copyright 2024 Zoé Cassiopée Gauthier.
#
# Use of this source code is governed by an MIT-style
# license that can be found in the LICENSE file or at
# https://opensource.org/licenses/MIT.
from typing import Literal
PokemonType = Literal[
"POKEMON_TYPE_NONE",
"POKEMON_TYPE_BUG",
"POKEMON_TYPE_DARK",
"POKEMON_TYPE_DRAGON",
"POKEMON_TYPE_ELECTRIC",
"POKEMON_TYPE_FAIRY",
"POKEMON_TYPE_FIGHTING",
"POKEMON_TYPE_FIRE",
"POKEMON_TYPE_FLYING",
"POKEMON_TYPE_GHOST",
"POKEMON_TYPE_GRASS",
"POKEMON_TYPE_GROUND",
"POKEMON_TYPE_ICE",
"POKEMON_TYPE_NORMAL",
"POKEMON_TYPE_POISON",
"POKEMON_TYPE_PSYCHIC",
"POKEMON_TYPE_ROCK",
"POKEMON_TYPE_STEEL",
"POKEMON_TYPE_WATER",
]

View File

@ -4,7 +4,11 @@
# license that can be found in the LICENSE file or at # license that can be found in the LICENSE file or at
# https://opensource.org/licenses/MIT. # https://opensource.org/licenses/MIT.
POKEMON_TYPE_COLORS = { from typing import Final
from pogo_scaled_estimators.typing import PokemonType
POKEMON_TYPE_COLORS: Final[dict[PokemonType, str]] = {
"POKEMON_TYPE_BUG": "green_yellow", "POKEMON_TYPE_BUG": "green_yellow",
"POKEMON_TYPE_DARK": "bright_black", "POKEMON_TYPE_DARK": "bright_black",
"POKEMON_TYPE_DRAGON": "dodger_blue2", "POKEMON_TYPE_DRAGON": "dodger_blue2",
@ -27,7 +31,7 @@ POKEMON_TYPE_COLORS = {
MINIMUM_SPECIAL_NAME_PARTS = 2 MINIMUM_SPECIAL_NAME_PARTS = 2
def format_pokemon_name(name: str, pokemon_type: str | None = None): def format_pokemon_name(name: str, pokemon_type: PokemonType | None = None):
parts = [part.capitalize() for part in name.split("_")] parts = [part.capitalize() for part in name.split("_")]
if parts[-1] == "Mega" or parts[-1] == "Primal": if parts[-1] == "Mega" or parts[-1] == "Primal":
parts = [parts[-1]] + parts[:-1] parts = [parts[-1]] + parts[:-1]
@ -41,7 +45,7 @@ def format_pokemon_name(name: str, pokemon_type: str | None = None):
return formatted_name return formatted_name
def format_move_name(name, move_type: str | None = None): def format_move_name(name: str, move_type: PokemonType | None = None):
parts = [part.capitalize() for part in name.split("_")] parts = [part.capitalize() for part in name.split("_")]
if parts[-1] == "Fast": if parts[-1] == "Fast":
parts = parts[:-1] parts = parts[:-1]

5
tests/__init__.py Normal file
View File

@ -0,0 +1,5 @@
# Copyright 2024 Zoé Cassiopée Gauthier.
#
# Use of this source code is governed by an MIT-style
# license that can be found in the LICENSE file or at
# https://opensource.org/licenses/MIT.

View File

@ -0,0 +1,55 @@
# Copyright 2024 Zoé Cassiopée Gauthier.
#
# Use of this source code is governed by an MIT-style
# license that can be found in the LICENSE file or at
# https://opensource.org/licenses/MIT.
import unittest.mock
import pytest
import requests
from pogo_scaled_estimators.pokebattler_proxy import MovesetResult, PokebattlerProxy, Raid
class MockSession(requests.Session):
def __init__(self, *args, **kwargs):
kwargs.pop("cache_control", None)
kwargs.pop("use_cache_dir", None)
super().__init__(*[], **kwargs)
@pytest.fixture(autouse=True)
def _disable_requests_cache():
"""Replace CachedSession with a regular Session for all test functions"""
with unittest.mock.patch("requests_cache.CachedSession", MockSession):
yield
def test_simulate(requests_mock):
requests_mock.get(
"https://fight.pokebattler.com/raids/defenders/MEWTWO_SHADOW_FORM/levels/RAID_LEVEL_5_SHADOW/attackers/levels/40/strategies/CINEMATIC_ATTACK_WHEN_POSSIBLE/DEFENSE_RANDOM_MC",
json={
"attackers": [
{
"randomMove": {
"move1": "RANDOM",
"move2": "RANDOM",
"defenders": [
{
"pokemonId": "BIDOOF",
"byMove": [
{"move1": "TACKLE_FAST", "move2": "HYPER_FANG", "result": {"estimator": 5.4321}},
],
"cp": 721,
}
],
}
}
]
},
)
pokebattler_proxy = PokebattlerProxy()
raid = Raid("RAID_LEVEL_5_SHADOW", "MEWTWO_SHADOW_FORM")
results = pokebattler_proxy.simulate(raid)
assert results == {"BIDOOF": [MovesetResult("TACKLE_FAST", "HYPER_FANG", 5.4321)]}