Skip to content

Commit

Permalink
Use unique_id in config flow
Browse files Browse the repository at this point in the history
  • Loading branch information
postlund committed Mar 20, 2020
1 parent 11d00df commit 471081a
Show file tree
Hide file tree
Showing 10 changed files with 79 additions and 92 deletions.
2 changes: 1 addition & 1 deletion custom_components/apple_tv/.translations/en.json
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
},
"options": {
"step": {
"device_options": {
"init": {
"description": "Configure General Device Settings",
"data": {
"start_off": "Do not turn device on when starting Home Assistant"
Expand Down
4 changes: 2 additions & 2 deletions custom_components/apple_tv/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ def _auth_problem(self):
_LOGGER.debug("Authentication error, reconfigure integration")

name = self.config_entry.data.get(CONF_NAME)
identifier = self.config_entry.data.get(CONF_IDENTIFIER)
identifier = self.config_entry.unique_id

self.hass.components.persistent_notification.create(
"An irrecoverable connection problem occurred when connecting to "
Expand All @@ -265,7 +265,7 @@ def _auth_problem(self):
)

async def _scan(self):
identifier = self.config_entry.data[CONF_IDENTIFIER]
identifier = self.config_entry.unique_id
address = self.config_entry.data[CONF_ADDRESS]
protocol = Protocol(self.config_entry.data[CONF_PROTOCOL])

Expand Down
67 changes: 25 additions & 42 deletions custom_components/apple_tv/config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,19 @@
CONF_TYPE,
)
from homeassistant.core import callback
from homeassistant.data_entry_flow import AbortFlow
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.aiohttp_client import async_get_clientsession

from .const import ( # pylint: disable=unused-import
CONF_CREDENTIALS,
from .const import (
CONF_CREDENTIALS_AIRPLAY,
CONF_CREDENTIALS_DMAP,
CONF_CREDENTIALS_MRP,
CONF_IDENTIFIER,
CONF_START_OFF,
DOMAIN,
)
from .const import CONF_CREDENTIALS # pylint: disable=unused-import

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -100,23 +101,19 @@ def __init__(self):
"""Initialize a new AppleTVConfigFlow."""
self.scan_result = None
self.atv = None
self.identifier = None
self.protocol = None
self.pairing = None
self.credentials = {} # Protocol -> credentials

async def async_step_invalid_credentials(self, info):
"""Handle initial step when updating invalid credentials."""
self.identifier = info.get(CONF_IDENTIFIER)
await self.async_set_unique_id(info.get(CONF_IDENTIFIER))

# pylint: disable=no-member # https://github.com/PyCQA/pylint/issues/3167
self.context["title_placeholders"] = {"name": info.get(CONF_NAME)}

await self.async_set_unique_id(self.identifier)
self._abort_if_unique_id_configured()

# pylint: disable=no-member # https://github.com/PyCQA/pylint/issues/3167
self.context["identifier"] = self.identifier
self.context["identifier"] = self.unique_id
return await self.async_step_reconfigure()

async def async_step_reconfigure(self, user_input=None):
Expand All @@ -137,7 +134,7 @@ async def async_step_user(self, user_input=None):
errors = {}
default_suggestion = self._prefill_identifier()
if user_input is not None:
self.identifier = user_input[CONF_IDENTIFIER]
await self.async_set_unique_id(user_input[CONF_IDENTIFIER])
try:
await self.async_find_device()
return await self.async_step_confirm()
Expand All @@ -152,7 +149,7 @@ async def async_step_user(self, user_input=None):
errors["base"] = "unknown"

# Use whatever the user entered as default value
default_suggestion = self.identifier
default_suggestion = self.unique_id

return self.async_show_form(
step_id="user",
Expand All @@ -170,23 +167,22 @@ async def async_step_zeroconf(self, discovery_info):
properties = discovery_info["properties"]

if service_type == "_mediaremotetv._tcp.local.":
self.identifier = properties["UniqueIdentifier"]
identifier = properties["UniqueIdentifier"]
name = properties["Name"]
elif service_type == "_touch-able._tcp.local.":
self.identifier = discovery_info["name"].split(".")[0]
identifier = discovery_info["name"].split(".")[0]
name = properties["CtlN"]
elif service_type == "_appletv-v2._tcp.local.":
self.identifier = discovery_info["name"].split(".")[0]
identifier = discovery_info["name"].split(".")[0]
name = properties["Name"] + " (Home Sharing)"
else:
return self.async_abort(reason="unrecoverable_error")

for flow in self._async_in_progress():
if flow["context"].get("identifier") == self.identifier:
return self.async_abort(reason="already_configured")
await self.async_set_unique_id(identifier, raise_on_progress=True)
self._abort_if_unique_id_configured()

# pylint: disable=no-member # https://github.com/PyCQA/pylint/issues/3167
self.context["identifier"] = self.identifier
self.context["identifier"] = self.unique_id
self.context["title_placeholders"] = {"name": name}
return await self.async_find_device_wrapper(self.async_step_confirm)

Expand All @@ -211,7 +207,7 @@ async def async_find_device_wrapper(self, next_func, allow_exist=False):
async def async_find_device(self, allow_exist=False):
"""Scan for the selected device to discover services."""
self.scan_result, self.atv = await device_scan(
self.identifier, self.hass.loop, cache=self.scan_result
self.unique_id, self.hass.loop, cache=self.scan_result
)
if not self.atv:
raise DeviceNotFound()
Expand All @@ -220,7 +216,7 @@ async def async_find_device(self, allow_exist=False):

if not allow_exist:
for identifier in self.atv.all_identifiers:
if self._is_already_configured(identifier):
if identifier in self._async_current_ids():
raise DeviceAlreadyConfigured()

# If credentials were found, save them
Expand All @@ -247,8 +243,8 @@ async def async_begin_pairing(self):

# Any more protocols to pair? Else bail out here
if not self.protocol:
await self.async_set_unique_id(self.atv.main_service().identifier)
return await self._async_get_entry(
self.atv.identifier,
self.atv.main_service().protocol,
self.atv.name,
self.credentials,
Expand Down Expand Up @@ -295,6 +291,8 @@ async def async_step_pair_with_pin(self, user_input=None):
return await self.async_begin_pairing()
except exceptions.PairingError:
errors["base"] = "auth"
except AbortFlow:
raise
except Exception: # pylint: disable=broad-except
_LOGGER.exception("Unexpected exception")
errors["base"] = "unknown"
Expand Down Expand Up @@ -340,15 +338,14 @@ async def async_step_service_problem(self, user_input=None):

async def async_step_import(self, info):
"""Import device from configuration file."""
self.identifier = info.get(CONF_IDENTIFIER)
await self.async_set_unique_id(info.get(CONF_IDENTIFIER))

_LOGGER.debug("Importing device with identifier %s", self.identifier)
_LOGGER.debug("Importing device with identifier %s", self.unique_id)
creds = {
CREDENTIAL_MAPPING[prot]: creds
for prot, creds in info.get(CONF_CREDENTIALS).items()
}
return await self._async_get_entry(
info.get(CONF_IDENTIFIER),
const.Protocol[info.get(CONF_PROTOCOL)],
info.get(CONF_NAME),
creds,
Expand All @@ -357,24 +354,19 @@ async def async_step_import(self, info):
)

async def _async_get_entry(
self, identifier, protocol, name, credentials, address, is_import=False
self, protocol, name, credentials, address, is_import=False
):
if not is_valid_credentials(credentials):
return self.async_abort(reason="invalid_config")

data = {
CONF_IDENTIFIER: identifier,
CONF_PROTOCOL: protocol.value,
CONF_NAME: name,
CONF_CREDENTIALS: credentials,
CONF_ADDRESS: str(address),
}

config_entry = self._get_config_entry(identifier)
if config_entry:
config_entry.data.update(data)
self.hass.config_entries.async_update_entry(config_entry)
return self.async_abort(reason="updated_configuration")
self._abort_if_unique_id_configured(updates=data)

title = name + (" (import from configuration.yaml)" if is_import else "")
return self.async_create_entry(title=title, data=data)
Expand All @@ -395,26 +387,17 @@ def _devices_str(self):
[
f"`{atv.name} ({atv.address})`"
for atv in self.scan_result
if not self._is_already_configured(atv.identifier)
if atv.identifier not in self._async_current_ids()
]
)

def _prefill_identifier(self):
# Return identifier (address) of one device that has not been paired with
for atv in self.scan_result:
if not self._is_already_configured(atv.identifier):
if atv.identifier not in self._async_current_ids():
return str(atv.address)
return ""

def _is_already_configured(self, identifier):
return self._get_config_entry(identifier) is not None

def _get_config_entry(self, identifier):
for entry in self._async_current_entries():
if entry.data[CONF_IDENTIFIER] == identifier:
return entry
return None


class AppleTVOptionsFlow(config_entries.OptionsFlow):
"""Handle Apple TV options."""
Expand All @@ -431,7 +414,7 @@ async def async_step_init(self, user_input=None):
return self.async_create_entry(title="", data=self.options)

return self.async_show_form(
step_id="device_options",
step_id="init",
data_schema=vol.Schema(
{
vol.Optional(
Expand Down
4 changes: 2 additions & 2 deletions custom_components/apple_tv/media_player.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from homeassistant.core import callback
import homeassistant.util.dt as dt_util

from .const import CONF_IDENTIFIER, DOMAIN
from .const import DOMAIN

_LOGGER = logging.getLogger(__name__)

Expand All @@ -51,7 +51,7 @@

async def async_setup_entry(hass, config_entry, async_add_entities):
"""Load Apple TV media player based on a config entry."""
identifier = config_entry.data[CONF_IDENTIFIER]
identifier = config_entry.unique_id
name = config_entry.data[CONF_NAME]
manager = hass.data[DOMAIN][config_entry.unique_id]
async_add_entities([AppleTvDevice(name, identifier, manager)])
Expand Down
15 changes: 2 additions & 13 deletions custom_components/apple_tv/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,16 @@
from homeassistant.const import CONF_NAME
from homeassistant.core import callback

from .const import CONF_IDENTIFIER, DOMAIN
from .const import DOMAIN

_LOGGER = logging.getLogger(__name__)

PARALLEL_UPDATES = 0


async def async_setup_platform(hass, config, async_add_entities, discovery_info=None):
"""Set up the Apple TV remote platform."""
if not discovery_info:
return

identifier = discovery_info[CONF_IDENTIFIER]
name = discovery_info[CONF_NAME]
manager = hass.data[DOMAIN][identifier]
async_add_entities([AppleTVRemote(name, identifier, manager)])


async def async_setup_entry(hass, config_entry, async_add_entities):
"""Load Apple TV remote based on a config entry."""
identifier = config_entry.data[CONF_IDENTIFIER]
identifier = config_entry.unique_id
name = config_entry.data[CONF_NAME]
manager = hass.data[DOMAIN][config_entry.unique_id]
async_add_entities([AppleTVRemote(name, identifier, manager)])
Expand Down
2 changes: 1 addition & 1 deletion custom_components/apple_tv/strings.json
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
},
"options": {
"step": {
"device_options": {
"init": {
"description": "Configure General Device Settings",
"data": {
"start_off": "Do not turn device on when starting Home Assistant"
Expand Down
6 changes: 3 additions & 3 deletions run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ if grep docker /proc/1/cgroup -qa; then
source venv/bin/activate
python setup.py develop

cmd="pytest -v --cov-report=term-missing --cov-report=html --cov homeassistant.components.apple_tv.config_flow --disable-warnings tests/components/apple_tv"
cmd="pytest -vv --cov-report=term-missing --cov-report=html --cov homeassistant.components.apple_tv.config_flow --disable-warnings tests/components/apple_tv"
if [ "$1" = "loop" ]; then
cmd="while true; do $cmd; echo Press enter...; read t; if [ ! "\$t" = '' ]; then break; fi; done"
elif [ $# -gt 0 ]; then
Expand All @@ -18,6 +18,7 @@ if grep docker /proc/1/cgroup -qa; then

./script/gen_requirements_all.py
python3 -m script.hassfest
pip install tox
pip install -qqq -r requirements_test_all.txt
pip uninstall -y asyncio typing
sed -i '/apple_tv/d' .coveragerc
Expand All @@ -28,10 +29,9 @@ else
-v $COMP_DIR/../home-assistant:/ha \
-v $COMP_DIR/../conf:/ha/conf \
-v $COMP_DIR/htmlcov:/ha/htmlcov \
-v $COMP_DIR/run.sh:/ha/run.sh \
-v $COMP_DIR/custom_components/apple_tv:/ha/homeassistant/components/apple_tv \
-v $COMP_DIR/tests/apple_tv:/ha/tests/components/apple_tv \
-v $COMP_DIR/run.sh:/ha/run.sh \
hadev:latest \
bash -c "$0 $*"
fi

26 changes: 20 additions & 6 deletions tests/apple_tv/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from pyatv import conf, interface
from pyatv.const import Protocol

from homeassistant.data_entry_flow import AbortFlow


class MockPairingHandler(interface.PairingHandler):
"""Mock for PairingHandler in pyatv."""
Expand Down Expand Up @@ -72,6 +74,7 @@ def __init__(self, flow):
self.flow = flow
self.name = None
self.result = None
self.exception = None

def __getattr__(self, attr):
"""Return correct action method dynamically based on name."""
Expand All @@ -91,15 +94,19 @@ def __getattr__(self, attr):
gives_type, gives_name = name.split("_", 1)
return partial(getattr(self, "_" + gives_type), gives_name)

async def _init(self, **data):
async def _init(self, has_input=True, **user_input):
args = {**user_input} if has_input else None
self.result = await self.flow.hass.config_entries.flow.async_init(
"apple_tv", data={**data}, context={"source": self.name}
"apple_tv", data=args, context={"source": self.name}
)
return self

async def _step(self, has_input=True, **user_input):
args = {**user_input} if has_input else None
self.result = await getattr(self.flow, "async_step_" + self.name)(args)
try:
self.result = await getattr(self.flow, "async_step_" + self.name)(args)
except AbortFlow as ex:
self.exception = ex
return self

def _form(self, step_id, **kwargs):
Expand All @@ -108,13 +115,20 @@ def _form(self, step_id, **kwargs):
for key, value in kwargs.items():
assert self.result[key] == value

def _create_entry(self, entry):
def _create_entry(self, entry, unique_id=None):
assert self.result["type"] == "create_entry"
assert self.result["data"] == entry

if unique_id:
print(unique_id, self.flow.unique_id)
assert self.flow.unique_id == unique_id

def _abort(self, reason):
assert self.result["type"] == "abort"
assert self.result["reason"] == reason
if self.result:
assert self.result["type"] == "abort"
assert self.result["reason"] == reason
else:
assert self.exception.reason, reason


def create_conf(name, address, *services):
Expand Down
Loading

0 comments on commit 471081a

Please sign in to comment.