diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 0000000..9cbc194 --- /dev/null +++ b/.coveragerc @@ -0,0 +1,8 @@ +[run] +branch = True +source = serverauditor_sshconfig +include = *.py + +[report] +precision = 2 +show_missing = True diff --git a/.gitignore b/.gitignore index 1481bfb..112cdb9 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,8 @@ .DS_Store .idea +.tox +.coverage *.pyc *.pyo diff --git a/.noserc b/.noserc new file mode 100644 index 0000000..45ef7ab --- /dev/null +++ b/.noserc @@ -0,0 +1,2 @@ +[nosetests] +with-coverage=1 \ No newline at end of file diff --git a/.prospector.yaml b/.prospector.yaml new file mode 100644 index 0000000..92d81be --- /dev/null +++ b/.prospector.yaml @@ -0,0 +1,16 @@ +inherits: + - strictness_veryhigh + - full_pep8 + - doc_warnings + +ignore-paths: + - .git + +pylint: + options: + max-parents: 12 + disable: + # - broad-except + # - pointless-except + # - bad-super-call + # - nonstandard-exception diff --git a/serverauditor_sshconfig/account/commands.py b/serverauditor_sshconfig/account/commands.py index 578055c..09773ab 100644 --- a/serverauditor_sshconfig/account/commands.py +++ b/serverauditor_sshconfig/account/commands.py @@ -9,16 +9,20 @@ from .managers import AccountManager -class LoginCommand(AbstractCommand): +class BaseAccountCommand(AbstractCommand): + + def __init__(self, app, app_args, cmd_name=None): + super(BaseAccountCommand, self).__init__(app, app_args, cmd_name) + self.manager = AccountManager(self.config) + + +class LoginCommand(BaseAccountCommand): """Sign into serverauditor cloud.""" def prompt_username(self): return six.moves.input("Serverauditor's username: ") - def prompt_password(self): - return getpass("Serverauditor's password: ") - def get_parser(self, prog_name): parser = super(LoginCommand, self).get_parser(prog_name) parser.add_argument('-u', '--username', metavar='USERNAME') @@ -27,14 +31,13 @@ def get_parser(self, prog_name): return parser def take_action(self, parsed_args): - manager = AccountManager(self.app.NAME) username = parsed_args.username or self.prompt_username() password = parsed_args.password or self.prompt_password() - manager.login(username, password) + self.manager.login(username, password) self.log.info('Sign into serverauditor cloud.') -class LogoutCommand(AbstractCommand): +class LogoutCommand(BaseAccountCommand): """Sign out serverauditor cloud.""" @@ -44,6 +47,5 @@ def get_parser(self, prog_name): return parser def take_action(self, parsed_args): - manager = AccountManager(self.app.NAME) - manager.logout() + self.manager.logout() self.log.info('Sign out serverauditor cloud.') diff --git a/serverauditor_sshconfig/account/managers.py b/serverauditor_sshconfig/account/managers.py index 71373fd..9ae5d97 100644 --- a/serverauditor_sshconfig/account/managers.py +++ b/serverauditor_sshconfig/account/managers.py @@ -5,14 +5,13 @@ License BSD, see LICENSE for more details. """ -from ..core.settings import Config from ..core.api import API class AccountManager(object): - def __init__(self, application_name): - self.config = Config(application_name) + def __init__(self, config): + self.config = config self.api = API() def login(self, username, password): diff --git a/serverauditor_sshconfig/cloud/__init__.py b/serverauditor_sshconfig/cloud/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/serverauditor_sshconfig/cloud/commands.py b/serverauditor_sshconfig/cloud/commands.py new file mode 100644 index 0000000..118dd89 --- /dev/null +++ b/serverauditor_sshconfig/cloud/commands.py @@ -0,0 +1,50 @@ +from base64 import b64decode +from ..core.commands import AbstractCommand +from .controllers import ApiController +from .cryptor import RNCryptor + + +class PushCommand(AbstractCommand): + + """Push data to Serverauditor cloud.""" + + def get_parser(self, prog_name): + parser = super(PushCommand, self).get_parser(prog_name) + parser.add_argument( + '-s', '--silent', action='store_true', + help='Do not produce any interactions.' + ) + parser.add_argument( + '-S', '--strategy', metavar='STRATEGY_NAME', + help='Force to use specific strategy to merge data.' + ) + return parser + + def take_action(self, parsed_args): + self.log.info('Push data to Serverauditor cloud.') + + +class PullCommand(AbstractCommand): + + """Pull data from Serverauditor cloud.""" + + def get_parser(self, prog_name): + parser = super(PullCommand, self).get_parser(prog_name) + parser.add_argument( + '-s', '--strategy', metavar='STRATEGY_NAME', + help='Force to use specific strategy to merge data.' + ) + return parser + + def take_action(self, parsed_args): + encryption_salt = b64decode(self.config.get('User', 'salt')) + hmac_salt = b64decode(self.config.get('User', 'hmac_salt')) + password = self.prompt_password() + cryptor = RNCryptor() + cryptor.password = password + cryptor.encryption_salt = encryption_salt + cryptor.hmac_salt = hmac_salt + controller = ApiController(self.storage, self.config, cryptor) + with self.storage: + controller.get_bulk() + self.log.info('Pull data from Serverauditor cloud.') diff --git a/serverauditor_sshconfig/cloud/controllers.py b/serverauditor_sshconfig/cloud/controllers.py new file mode 100644 index 0000000..c71775d --- /dev/null +++ b/serverauditor_sshconfig/cloud/controllers.py @@ -0,0 +1,78 @@ +from .serializers import BulkSerializer +from ..core.api import API + + +class CryptoController(object): + + def __init__(self, cryptor): + self.cryptor = cryptor + + def _mutate_fields(self, model, mutator): + for i in model.crypto_fields: + crypto_field = getattr(model, i) + if crypto_field: + setattr(model, i, mutator(crypto_field)) + return model + + def encrypt(self, model): + return self._mutate_fields(model, self.cryptor.encrypt) + + def decrypt(self, model): + return self._mutate_fields(model, self.cryptor.decrypt) + + +class ApiController(object): + + mapping = dict( + bulk=dict(url='v2/terminal/bulk/', serializer=BulkSerializer) + ) + + def __init__(self, storage, config, cryptor): + self.config = config + username = self.config.get('User', 'username') + apikey = self.config.get('User', 'apikey') + assert username + assert apikey + self.api = API(username, apikey) + self.storage = storage + self.crypto_controller = CryptoController(cryptor) + + def _get(self, mapped): + serializer = mapped['serializer']( + storage=self.storage, crypto_controller=self.crypto_controller + ) + response = self.api.get(mapped['url']) + + model = serializer.to_model(response) + return model + + def get_bulk(self): + mapped = self.mapping['bulk'] + model = self._get(mapped) + self.config.set('CloudSynchronization', 'last_synced', + model['last_synced']) + self.config.write() + + def _post(self, mapped, request_model): + request_model = request_model + serializer = mapped['serializer']( + storage=self.storage, crypto_controller=self.crypto_controller + ) + + payload = serializer.to_payload(request_model) + response = self.api.post(mapped['url'], payload) + + response_model = serializer.to_models(response) + return response_model + + def post_bulk(self): + mapped = self.mapping['bulk'] + model = {} + model['last_synced'] = self.config.get( + 'CloudSynchronization', 'last_synced' + ) + assert model['last_synced'] + out_model = self._post(mapped, model) + self.config.set('CloudSynchronization', 'last_synced', + out_model['last_synced']) + self.config.write() diff --git a/serverauditor_sshconfig/core/cryptor.py b/serverauditor_sshconfig/cloud/cryptor.py similarity index 84% rename from serverauditor_sshconfig/core/cryptor.py rename to serverauditor_sshconfig/cloud/cryptor.py index 846ca39..a4e9b56 100644 --- a/serverauditor_sshconfig/core/cryptor.py +++ b/serverauditor_sshconfig/cloud/cryptor.py @@ -16,7 +16,7 @@ from Crypto.Protocol import KDF from Crypto import Random -from .utils import bchr, bord, to_bytes, to_str +from ..core.utils import bchr, bord, to_bytes, to_str class CryptorException(Exception): @@ -161,34 +161,3 @@ def _pbkdf2(self, password, salt, iterations=10000, key_length=32): ## passlib version -- the fastest version # from passlib.utils.pbkdf2 import pbkdf2 # return pbkdf2(password, salt, iterations, key_length) - - -def main(): - from time import time - - cryptor = RNCryptor() - cryptor.encryption_salt = b'1' * 8 - cryptor.hmac_salt = b'1' * 8 - - passwords = 'p@s$VV0Rd', 'пароль' - texts = 'www.crystalnix.com', 'текст', '', '1' * 16, '2' * 15, '3' * 17 - - for password in passwords: - cryptor.password = password - for text in texts: - print('text: "{}"'.format(text)) - - s = time() - encrypted_data = cryptor.encrypt(text) - print('encrypted {}: "{}"'.format(time() - s, encrypted_data)) - - s = time() - decrypted_data = cryptor.decrypt(encrypted_data) - print('decrypted {}: "{}"\n'.format(time() - s, decrypted_data)) - - assert text == decrypted_data - - -if __name__ == '__main__': - - main() \ No newline at end of file diff --git a/serverauditor_sshconfig/cloud/group/__init__.py b/serverauditor_sshconfig/cloud/group/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/serverauditor_sshconfig/cloud/host/__init__.py b/serverauditor_sshconfig/cloud/host/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/serverauditor_sshconfig/models/__init__.py b/serverauditor_sshconfig/cloud/models.py similarity index 59% rename from serverauditor_sshconfig/models/__init__.py rename to serverauditor_sshconfig/cloud/models.py index f5164c5..2d8f4fe 100644 --- a/serverauditor_sshconfig/models/__init__.py +++ b/serverauditor_sshconfig/cloud/models.py @@ -5,12 +5,14 @@ class Tag(Model): fields = {'label'} set_name = 'tag_set' + crypto_fields = fields class SshKey(Model): fields = {'label', 'passphrase', 'private_key', 'public_key'} set_name = 'sshkeycrypt_set' + crypto_fields = fields class SshIdentity(Model): @@ -20,7 +22,7 @@ class SshIdentity(Model): mapping = { 'ssh_key': Mapping(SshKey, many=False), } - + crypto_fields = {'label', 'username', 'password'} class SshConfig(Model): @@ -38,6 +40,7 @@ class Group(Model): mapping = { 'ssh_config': Mapping(SshConfig, many=False), } + crypto_fields = {'label',} Group.mapping['parent_group'] = Mapping(Group, many=False) @@ -45,11 +48,34 @@ class Group(Model): class Host(Model): - fields = {'label', 'address', 'group', 'tags', 'address', 'ssh_config'} + fields = {'label', 'group', # 'tags', + 'address', 'ssh_config'} set_name = 'host_set' mapping = { 'ssh_config': Mapping(SshConfig, many=False), - 'tags': Mapping(Tag, many=True), + # 'tags': Mapping(Tag, many=True), + } + crypto_fields = {'label', 'address'} + + +class Host(Model): + + fields = {'label', 'group', 'address', 'ssh_config'} + set_name = 'host_set' + mapping = { + 'ssh_config': Mapping(SshConfig, many=False), + # 'tags': Mapping(Tag, many=True), + } + crypto_fields = {'label', 'address'} + + +class TagHost(Model): + + fields = {'host', 'tag'} + set_name = 'taghost_set' + mapping = { + 'host': Mapping(Host, many=False), + 'tag': Mapping(Tag, many=False), } @@ -61,3 +87,4 @@ class PFRule(Model): mapping = { 'host': Mapping(Host, many=False), } + crypto_fields = {'label', 'bound_address', 'hostname'} diff --git a/serverauditor_sshconfig/cloud/pfrlule/__init__.py b/serverauditor_sshconfig/cloud/pfrlule/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/serverauditor_sshconfig/cloud/serializers.py b/serverauditor_sshconfig/cloud/serializers.py new file mode 100644 index 0000000..45f310a --- /dev/null +++ b/serverauditor_sshconfig/cloud/serializers.py @@ -0,0 +1,211 @@ +"""Serializers (read controllers) is like django rest framework serializers.""" + +import six +import abc +from collections import OrderedDict +from operator import attrgetter, itemgetter +from ..core.models import RemoteInstance +from ..core.exceptions import DoesNotExistException +from .models import ( + Host, Group, + Tag, SshKey, + SshIdentity, SshConfig, + Group, Host, + PFRule, TagHost, +) + + +def zip_model_fields(model, field_getter=None): + field_getter = field_getter or attrgetter(model.fields) + return zip(model.fields, field_getter(model)) + + +@six.add_metaclass(abc.ABCMeta) +class Serializer(object): + + def __init__(self, storage): + assert storage + self.storage = storage + + @abc.abstractmethod + def to_model(self, payload): + """Convert REST API payload to Application models.""" + + @abc.abstractmethod + def to_payload(self, model): + """Convert Application models to REST API payload.""" + + +class BulkEntryBaseSerializer(Serializer): + + def __init__(self, model_class, **kwargs): + super(BulkEntryBaseSerializer, self).__init__(**kwargs) + assert model_class + self.model_class = model_class + + +class BulkPrimaryKeySerializer(BulkEntryBaseSerializer): + + to_model_mapping = { + int: int, + dict: itemgetter('id') + } + + def id_from_payload(self, payload): + return self.to_model_mapping[type(payload)](payload) + + def to_model(self, payload): + if not payload: + return None + + remote_instance_id = self.id_from_payload(payload) + + model = self.storage.get( + self.model_class, + **{'remote_instance.id': remote_instance_id} + ) + return model + + def to_payload(self, model): + if not model: + return None + if model.remote_instance: + return model.remote_instance.id + else: + return '{model.set_name}/{model.id}'.format(model=model) + + +class BulkEntrySerializer(BulkPrimaryKeySerializer): + + def __init__(self, **kwargs): + super(BulkEntrySerializer, self).__init__(**kwargs) + self.attrgetter = attrgetter(*self.model_class.fields) + self.remote_instance_attrgetter = attrgetter(*RemoteInstance.fields) + + def update_model_fields(self, model, payload): + for i in model.fields: + mapping = model.mapping.get(i) + if mapping: + serializer = BulkPrimaryKeySerializer( + storage=self.storage, model_class=mapping.model + ) + field = serializer.to_model(payload[i]) + else: + field = payload[i] + setattr(model, i, field) + return model + + def create_remote_instance(self, payload): + remote_instance = RemoteInstance() + for i in RemoteInstance.fields: + setattr(remote_instance, i, payload.pop(i)) + return remote_instance + + def get_or_initialize_model(self, payload): + try: + model = super(BulkEntrySerializer, self).to_model(payload) + except DoesNotExistException: + remote_instance = self.create_remote_instance(payload) + model = self.model_class() + model.remote_instance = remote_instance + model.update( + ((k, v) for k, v in payload.items() if k in model.fields) + ) + + model.id = payload.get('local_id', model.id) + return model + + def to_model(self, payload): + model = self.get_or_initialize_model(payload) + model = self.update_model_fields(model, payload) + return model + + def to_payload(self, model): + payload = dict(zip_model_fields(model, self.attrgetter)) + if model.remote_instance: + zipped_remote_instance = zip_model_fields( + model.remote_instance, self.remote_instance_attrgetter + ) + payload.update(zipped_remote_instance) + payload['local_id'] = model.id + for field, mapping in model.mapping.items(): + serializer = BulkPrimaryKeySerializer( + storage=self.storage, model_class=mapping.model + ) + fk_payload = serializer.to_payload(getattr(model, field)) + payload[field] = fk_payload + return payload + +class CryptoBulkEntrySerializer(BulkEntrySerializer): + + def __init__(self, crypto_controller, **kwargs): + super(CryptoBulkEntrySerializer, self).__init__(**kwargs) + self.crypto_controller = crypto_controller + + def to_model(self, payload): + model = super(CryptoBulkEntrySerializer, self).to_model(payload) + return self.crypto_controller.decrypt(model) + + def to_payload(self, model): + encrypted_model = self.crypto_controller.encrypt(model) + return super(CryptoBulkEntrySerializer, self).to_payload( + encrypted_model) + + +class BulkSerializer(Serializer): + + child_serializer_class = CryptoBulkEntrySerializer + supported_models = ( + SshKey, SshIdentity, SshConfig, Tag, Group, Host, PFRule, TagHost + ) + + def create_child_serializer(self, model_class): + return self.child_serializer_class( + model_class=model_class, storage=self.storage, + crypto_controller=self.crypto_controller + ) + + def __init__(self, crypto_controller, **kwargs): + super(BulkSerializer, self).__init__(**kwargs) + self.crypto_controller = crypto_controller + self.mapping = OrderedDict(( + (i.set_name, self.create_child_serializer(i)) + for i in self.supported_models + )) + + def process_model_entries(self, updated, deleted): + for i in updated: + self.storage.save(i) + for i in deleted: + self.storage.delete(i) + + def to_model(self, payload): + models = {} + models['last_synced'] = payload.pop('now') + models['deleted_sets'] = {} + deleted_sets = payload.pop('deleted_sets') + for set_name, serializer in self.mapping.items(): + models[set_name] = [serializer.to_model(i) for i in payload[set_name]] + serializer = BulkPrimaryKeySerializer( + storage=self.storage, model_class=serializer.model_class + ) + deleted_set = [] + for i in deleted_sets[set_name]: + try: + deleted_set.append(serializer.to_model(i)) + except DoesNotExistException: + pass + models['deleted_sets'][set_name] = deleted_set + + self.process_model_entries( + models[set_name], models['deleted_sets'][set_name] + ) + return models + + def to_payload(self, model): + payload = {} + payload['last_synced'] = model.pop('last_synced') + deleted_sets = payload.pop('deleted_sets') + for set_name, serializer in self.mapping.items(): + payload[set_name] = [serializer.to_model(i) for i in payload[set_name]] + raise RuntimeError('Need implement deleted_sets showing.') diff --git a/serverauditor_sshconfig/cloud/sshkeycrypt/__init__.py b/serverauditor_sshconfig/cloud/sshkeycrypt/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/serverauditor_sshconfig/core/api.py b/serverauditor_sshconfig/core/api.py index 8fe371d..fb37ec3 100644 --- a/serverauditor_sshconfig/core/api.py +++ b/serverauditor_sshconfig/core/api.py @@ -67,12 +67,12 @@ def post(self, endpoint, data): assert response.status_code == 201 return response.json() - def get(self, endpoint, data): + def get(self, endpoint): response = requests.get(self.request_url(endpoint), auth=self.auth) assert response.status_code == 200 return response.json() - def delete(self, endpoint, data): + def delete(self, endpoint): response = requests.delete(self.request_url(endpoint), auth=self.auth) assert response.status_code in (200, 204) return response.json() diff --git a/serverauditor_sshconfig/core/commands.py b/serverauditor_sshconfig/core/commands.py index 7e2a46d..f7a69bc 100644 --- a/serverauditor_sshconfig/core/commands.py +++ b/serverauditor_sshconfig/core/commands.py @@ -1,16 +1,28 @@ # coding: utf-8 import logging +import getpass from cliff.command import Command from cliff.lister import Lister +from .settings import Config, ApplicationStorage -class AbstractCommand(Command): +class PasswordPromptMixin(object): + def prompt_password(self): + return getpass.getpass("Serverauditor's password:") + + +class AbstractCommand(PasswordPromptMixin, Command): "Abstract Command with log." log = logging.getLogger(__name__) + def __init__(self, app, app_args, cmd_name=None): + super(AbstractCommand, self).__init__(app, app_args, cmd_name) + self.config = Config(self.app.NAME) + self.storage = ApplicationStorage(self.app.NAME) + def get_parser(self, prog_name): parser = super(AbstractCommand, self).get_parser(prog_name) parser.add_argument('--log-file', help="Path to log file.") @@ -38,6 +50,11 @@ def get_parser(self, prog_name): class ListCommand(Lister): + def __init__(self, app, app_args, cmd_name=None): + super(ListCommand, self).__init__(app, app_args, cmd_name) + self.config = Config(self.app.NAME) + self.storage = ApplicationStorage(self.app.NAME) + def get_parser(self, prog_name): parser = super(ListCommand, self).get_parser(prog_name) parser.add_argument( diff --git a/serverauditor_sshconfig/core/exceptions.py b/serverauditor_sshconfig/core/exceptions.py new file mode 100644 index 0000000..6358a2b --- /dev/null +++ b/serverauditor_sshconfig/core/exceptions.py @@ -0,0 +1,2 @@ +class DoesNotExistException(Exception): + pass diff --git a/serverauditor_sshconfig/core/models.py b/serverauditor_sshconfig/core/models.py index 6e1cafc..2a47423 100644 --- a/serverauditor_sshconfig/core/models.py +++ b/serverauditor_sshconfig/core/models.py @@ -1,5 +1,3 @@ -import six -import abc import copy from collections import namedtuple @@ -7,15 +5,14 @@ Mapping = namedtuple('Mapping', ('model', 'many')) -class Model(dict): +class AbstractModel(dict): fields = set() - - __mandatory_fields = {'id', 'remote_instance'} + _mandatory_fields = set() @classmethod def allowed_feilds(cls): - return tuple(cls.fields.union(cls.__mandatory_fields)) + return tuple(cls.fields.union(cls._mandatory_fields)) @classmethod def _validate_attr(cls, name): @@ -39,21 +36,37 @@ def copy(self): def __copy__(self): newone = type(self)() - newone.__dict__.update(self.__dict__) + newone.update(self) return newone def __deepcopy__(self, requesteddeepcopy): return type(self)(copy.deepcopy(super(Model, self))) + +class Model(AbstractModel): + + _mandatory_fields = {'id', 'remote_instance'} + + def __init__(self, *args, **kwargs): + super(Model, self).__init__(*args, **kwargs) + is_need_to_patch_remote_instance = ( + self.remote_instance and + not isinstance(self.remote_instance, RemoteInstance) + ) + if is_need_to_patch_remote_instance: + self.remote_instance = RemoteInstance(self.remote_instance) + # set_name = '' # """Key name in Application Storage.""" - # mapping = {} - # """Foreign key mapping - Mapping instances per field_name.""" + mapping = {} + """Foreign key mapping - Mapping instances per field_name.""" + crypto_fields = {} + """Set of fields for enrpyption and decryption on cloud.""" id_name = 'id' """Name of field to be used as identificator.""" -class RemoteInstance(Model): +class RemoteInstance(AbstractModel): fields = {'id', 'updated_at'} diff --git a/serverauditor_sshconfig/core/settings.py b/serverauditor_sshconfig/core/settings.py index 9018097..5bb7872 100644 --- a/serverauditor_sshconfig/core/settings.py +++ b/serverauditor_sshconfig/core/settings.py @@ -10,12 +10,25 @@ from uuid import uuid4 from collections import OrderedDict from .storage import PersistentDict +from .exceptions import DoesNotExistException def expand_and_format_path(paths, **kwargs): return [os.path.expanduser(i.format(**kwargs)) for i in paths] +def tupled_attrgetter(*items): + def g(obj): + return tuple(resolve_attr(obj, attr) for attr in items) + return g + + +def resolve_attr(obj, attr): + for name in attr.split("."): + obj = getattr(obj, name) + return obj + + class Config(object): paths = ['~/.{application_name}'] @@ -40,7 +53,7 @@ def touch_files(self): pass def get(self, *args, **kwargs): - self.config.get(*args, **kwargs) + return self.config.get(*args, **kwargs) def set(self, section, option, value): if not self.config.has_section(section): @@ -71,15 +84,15 @@ def __init__(self, storage): def __call__(self, model): """:param core.models.Model model: generate and set id for this Model.""" assert not getattr(model, model.id_name) - uuid = uuid4().int - setattr(model, model.id_name, uuid) - return uuid + identificator = uuid4().time_low + setattr(model, model.id_name, identificator) + return identificator class ApplicationStorage(object): path = '~/.{application_name}.storage' - defaultstorage = OrderedDict + defaultstorage = list def __init__(self, application_name, **kwargs): self._path = expand_and_format_path( @@ -91,29 +104,27 @@ def __init__(self, application_name, **kwargs): def generate_id(self, model): return self.id_generator(model) - def get_all(self, model_class): - assert isinstance(model_class, type) - name = model_class.set_name - dict_data = self.driver.setdefault(name, self.defaultstorage()) - if dict_data: - model_data = self.defaultstorage( - ((k, model_class(v)) for k, v in dict_data.items()) - ) - else: - model_data = dict_data - return model_data + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.driver.sync() def save_mapped_fields(self, model): + def save_instance(model): + if isinstance(model, six.integer_types): + return model + return self.save(model).id + def sub_save(field, mapping): submodel = getattr(model, field) - if not submodel: - return submodel if not mapping.many: - saved_submodel = self.save(submodel).id - else: - saved_submodel = [self.save(submodel).id for i in submodel] + saved_submodel = submodel and save_instance(submodel) + else: + submodel = submodel or [] + saved_submodel = [save_instance(i) for i in submodel] return saved_submodel model_copy = model.copy() @@ -139,11 +150,11 @@ def save(self, model): def create(self, model): assert not getattr(model, model.id_name) - models = self.get_all(type(model)) model_with_saved_subs = self.save_mapped_fields(model) - models[self.generate_id(model)] = model_with_saved_subs + model.id = self.generate_id(model_with_saved_subs) + models = self.get_all(type(model)) + models.append(model_with_saved_subs) self.driver[model.set_name] = models - self.driver.sync() return model_with_saved_subs def update(self, model): @@ -151,23 +162,44 @@ def update(self, model): assert identificator self.save_mapped_fields(model) - models = self.get_all(type(model)) model_with_saved_subs = self.save_mapped_fields(model) - models[identificator] = model_with_saved_subs + + self.delete(model) + models = self.get_all(type(model)) + models.append(model_with_saved_subs) self.driver[model.set_name] = models - self.driver.sync() def get(self, model_class, **kwargs): assert isinstance(model_class, type) + assert kwargs models = self.get_all(model_class) - identificator = kwargs.get(model_class.id_name) - return model_class(models[identificator]) + filter_keys = tuple(i[0] for i in kwargs.items()) + filter_values = tuple(i[1] for i in kwargs.items()) + getter = tupled_attrgetter(*filter_keys) + founded_models = [ + i for i in models if getter(i) == filter_values + ] + if not founded_models: + raise DoesNotExistException + assert len(founded_models) == 1 + return model_class(founded_models[0]) + + def get_all(self, model_class): + assert isinstance(model_class, type) + name = model_class.set_name + data = self.driver.setdefault(name, self.defaultstorage()) + models = self.defaultstorage( + (model_class(i) for i in data) + ) + return models def delete(self, model): identificator = getattr(model, model.id_name) assert identificator - models = self[model.set_name] - models.pop(identificator) + models = self.get_all(type(model)) + for index, model in enumerate(models): + if model.id == identificator: + models.pop(index) + break self.driver[model.set_name] = models - self.driver.sync() diff --git a/serverauditor_sshconfig/core/storage.py b/serverauditor_sshconfig/core/storage.py index 0fa3b34..b45abe2 100644 --- a/serverauditor_sshconfig/core/storage.py +++ b/serverauditor_sshconfig/core/storage.py @@ -126,12 +126,12 @@ def dump(self, fileobj): try: DRIVERS[self.format].dump(fileobj, self) except KeyError: - raise NotImplementedError('Unknown format: ' + repr(format)) + raise NotImplementedError('Unknown format: ' + repr(self.format)) def load(self, fileobj): for loader in DRIVERS.values(): try: return self.update(loader.load(fileobj)) - except Exception as e: + except Exception: pass raise ValueError('File not in a supported format') diff --git a/serverauditor_sshconfig/handlers.py b/serverauditor_sshconfig/handlers.py index 39a5597..0b903a4 100644 --- a/serverauditor_sshconfig/handlers.py +++ b/serverauditor_sshconfig/handlers.py @@ -1,8 +1,7 @@ # coding: utf-8 from operator import attrgetter from .core.commands import AbstractCommand, DetailCommand, ListCommand -from .core.settings import ApplicationStorage -from .models import Host, SshConfig, SshIdentity, SshKey, Tag, Group +from .cloud.models import Host, SshConfig, SshIdentity, SshKey, Tag, Group @@ -102,8 +101,6 @@ def create_host(self, parsed_args): if parsed_args.generate_key: pass # generate SshKey - storage = ApplicationStorage(self.app.NAME) - identity = SshIdentity() identity.username = parsed_args.username identity.password = parsed_args.password @@ -117,7 +114,8 @@ def create_host(self, parsed_args): host.address = parsed_args.address host.ssh_config = config - storage.save(host) + with self.storage: + self.storage.save(host) return host def take_action(self, parsed_args): @@ -145,12 +143,10 @@ def get_parser(self, prog_name): return parser def take_action(self, parsed_args): - storage = ApplicationStorage(self.app.NAME) - hosts = storage.get_all(Host) + hosts = self.storage.get_all(Host) fields = Host.allowed_feilds() getter = attrgetter(*fields) - import pudb;pudb.set_trace() - return fields, [getter(i) for i in hosts.values()] + return fields, [getter(i) for i in hosts] class GroupCommand(DetailCommand): @@ -265,46 +261,6 @@ def take_action(self, parsed_args): self.log.info('Tag objects.') -class PushCommand(AbstractCommand): - - """Push data to Serverauditor cloud.""" - - def get_parser(self, prog_name): - parser = super(PushCommand, self).get_parser(prog_name) - parser.add_argument( - '-s', '--silent', action='store_true', - help='Do not produce any interactions.' - ) - parser.add_argument( - '-S', '--strategy', metavar='STRATEGY_NAME', - help='Force to use specific strategy to merge data.' - ) - return parser - - def take_action(self, parsed_args): - self.log.info('Push data to Serverauditor cloud.') - - -class PullCommand(AbstractCommand): - - """Pull data from Serverauditor cloud.""" - - def get_parser(self, prog_name): - parser = super(PullCommand, self).get_parser(prog_name) - parser.add_argument( - '-s', '--silent', action='store_true', - help='Do not produce any interactions.' - ) - parser.add_argument( - '-S', '--strategy', metavar='STRATEGY_NAME', - help='Force to use specific strategy to merge data.' - ) - return parser - - def take_action(self, parsed_args): - self.log.info('Pull data from Serverauditor cloud.') - - class InfoCommand(AbstractCommand): """Show info about host or group.""" diff --git a/serverauditor_sshconfig/sa_export.py b/serverauditor_sshconfig/sa_export.py deleted file mode 100644 index a309315..0000000 --- a/serverauditor_sshconfig/sa_export.py +++ /dev/null @@ -1,188 +0,0 @@ -#!/usr/bin/env python -# coding: utf-8 - -""" -Copyright (c) 2013 Crystalnix. -License BSD, see LICENSE for more details. -""" - -import sys - -from serverauditor_sshconfig.core.application import SSHConfigApplication, description -from serverauditor_sshconfig.core.api import API -from serverauditor_sshconfig.core.cryptor import RNCryptor -from serverauditor_sshconfig.core.logger import PrettyLogger -from serverauditor_sshconfig.core.ssh_config import SSHConfig -from serverauditor_sshconfig.core.utils import p_input, p_map - - -class ExportSSHConfigApplication(SSHConfigApplication): - - def run(self): - self._greeting() - - self._get_sa_user() - self._get_sa_keys_and_connections() - self._decrypt_sa_keys_and_connections() - self._fix_sa_keys_and_connections() - - self._parse_local_config() - self._sync_for_export() - self._choose_new_hosts() - self._get_full_hosts() - - self._create_keys_and_connections() - - self._valediction() - return - - def _greeting(self): - self._logger.log("ServerAuditor's ssh config script. Export from your computer to SA account.", color='magenta') - return - - @description("Synchronization...") - def _sync_for_export(self): - def get_identity_files(host): - return [f[1] for f in host.get('identityfile', [])] - - def is_exist(host): - h = self._config.get_host(host, substitute=True) - has_key = bool(h.get('identityfile', None)) - for conn in self._sa_connections: - key_check = True - key_id = conn['ssh_key'] - if has_key: - if key_id: - key_check = self._sa_keys[key_id['id']]['private_key'] in get_identity_files(h) - else: - continue - else: - if key_id: - continue - - if (conn['hostname'] == h['hostname'] and - conn['ssh_username'] == h['user'] and - conn['port'] == int(h.get('port', 22)) and - key_check): # conn['label'] == h['host'] - return True - - return False - - for host in self._local_hosts[:]: - if is_exist(host): - self._local_hosts.remove(host) - - return - - @description(valediction="OK!") - def _choose_new_hosts(self): - def get_prompt(): - if self._local_hosts: - return "You may confirm this list (press 'Enter'), add (enter '+') or remove (enter its number) host: " - else: - return "You may confirm this list (press 'Enter') or add (enter '+') host: " - - def get_hosts_names(): - return ', '.join('%s (#%d)' % (h, i) for i, h in enumerate(self._local_hosts)) or '[]' - - self._logger.log("The following new hosts have been founded in your ssh config:", sleep=0) - self._logger.log(get_hosts_names(), color='blue') - - while True: - number = p_input(get_prompt()).strip() - - if number == '': - break - - if number == '+': - host = p_input("Enter host: ") - conf = self._config.get_host(host) - if list(conf.keys()) == ['host']: - self._logger.log('There is no config for host "%s"!' % host, color='red', file=sys.stderr) - else: - self._local_hosts.append(host) - - else: - try: - number = int(number) - if number >= len(self._local_hosts) or number < 0: - raise IndexError - except (ValueError, IndexError): - self._logger.log("Incorrect index!", color='red', file=sys.stderr) - continue - else: - self._local_hosts.pop(number) - - self._logger.log(get_hosts_names(), color='blue') - - if not self._local_hosts: - self._valediction() - sys.exit(0) - return - - @description("Getting full information...") - def _get_full_hosts(self): - def check_duplicates(hosts): - new_hosts = [] - new_hosts_ids = set() - new_hosts_names = {} - - # current serverauditor connections - for conn in self._sa_connections: - conn_id = self._get_sa_connection_uri(conn) - new_hosts_ids.add(conn_id) - new_hosts_names[conn_id] = self._get_sa_connection_name(conn) - - for host in hosts: - host_id = '{host[user]}@{host[hostname]}:{host[port]}'.format(host=host) - if not host_id in new_hosts_ids: - new_hosts_ids.add(host_id) - new_hosts.append(host) - new_hosts_names[host_id] = host['host'] - else: - self._logger.log('Seems "{cur_host}" is duplicate of "{ex_host}"!'.format( - cur_host=host['host'], - ex_host=new_hosts_names[host_id] - ), color='blue') - return new_hosts - - def encrypt_host(host): - host['host'] = self._cryptor.encrypt(host['host']) - host['hostname'] = self._cryptor.encrypt(host['hostname']) - host['user'] = self._cryptor.encrypt(host['user']) - host['password'] = '' - - host['ssh_key'] = [] - for i, f in enumerate(host.get('identityfile', [])): - ssh_key = { - 'label': self._cryptor.encrypt(f[0]), - 'private_key': self._cryptor.encrypt(f[1]), - 'public_key': '', - 'passphrase': '' - } - host['ssh_key'].append(ssh_key) - return host - - almost_full_local_hosts = [self._config.get_host(h, substitute=True) for h in self._local_hosts] - full_local_hosts = check_duplicates(almost_full_local_hosts) - self._full_local_hosts = p_map(encrypt_host, full_local_hosts) - return - - @description("Creating keys and connections...") - def _create_keys_and_connections(self): - self._api.create_keys_and_connections(self._full_local_hosts, self._sa_username, self._sa_auth_key) - return - - -def main(): - app = ExportSSHConfigApplication(api=API(), ssh_config=SSHConfig(), cryptor=RNCryptor(), logger=PrettyLogger()) - try: - app.run() - except (KeyboardInterrupt, EOFError): - sys.exit(1) - return - - -if __name__ == "__main__": - - main() diff --git a/serverauditor_sshconfig/sa_import.py b/serverauditor_sshconfig/sa_import.py deleted file mode 100644 index f779dcc..0000000 --- a/serverauditor_sshconfig/sa_import.py +++ /dev/null @@ -1,205 +0,0 @@ -#!/usr/bin/env python -# coding: utf-8 - -""" -Copyright (c) 2013 Crystalnix. -License BSD, see LICENSE for more details. -""" - -import os -import stat -import sys - -from serverauditor_sshconfig.core.application import SSHConfigApplication, description -from serverauditor_sshconfig.core.api import API -from serverauditor_sshconfig.core.cryptor import RNCryptor -from serverauditor_sshconfig.core.logger import PrettyLogger -from serverauditor_sshconfig.core.ssh_config import SSHConfig -from serverauditor_sshconfig.core.utils import p_input - - -class ImportSSHConfigApplication(SSHConfigApplication): - - SSH_KEYS_DIR = '~/.ssh/serverauditor/' - SSH_CONFIG_HOST_TEMPLATE = """\ -# The following host was created by ServerAuditor -Host {host} - User {user} - HostName {hostname} - Port {port} -""" - SSH_CONFIG_HOST_IDENTITY_FILE = """\ - IdentityFile {key} -""" - SSH_CONFIG_HOST_DUPLICATE = """\ -# Duplicate host will not work with SSH -""" - - def run(self): - self._greeting() - - self._get_sa_user() - self._get_sa_keys_and_connections() - self._decrypt_sa_keys_and_connections() - self._fix_sa_keys_and_connections() - - self._parse_local_config() - self._sync_for_import() - self._choose_new_hosts() - self._create_keys_and_connections() - - self._valediction() - return - - def _greeting(self): - self._logger.log("ServerAuditor's ssh config script. Import from SA account to your computer.", color='magenta') - return - - @description("Synchronization...") - def _sync_for_import(self): - def get_identity_files(host): - return [f[1] for f in host.get('identityfile', [])] - - def is_exist(conn): - attempt = conn['label'] or conn['hostname'] - h = self._config.get_host(attempt, substitute=True) - key_check = True - key_id = conn['ssh_key'] - if key_id: - key_check = self._sa_keys[key_id['id']]['private_key'] in get_identity_files(h) - return (conn['ssh_username'] == h['user'] and - conn['hostname'] == h['hostname'] and - conn['port'] == h['port'] and - key_check) - - for conn in self._sa_connections[:]: - if is_exist(conn): - name = conn['label'] or "%s@%s:%s" % (conn['ssh_username'], conn['hostname'], conn['port']) - self._logger.log('Connection "%s" can already be used by ssh.' % name, color='blue') - self._sa_connections.remove(conn) - - return - - @description(valediction="OK!") - def _choose_new_hosts(self): - def get_connection_name(conn, number): - name = self._get_sa_connection_name(conn) - return '%s (#%d)' % (name, number) - - def get_connections_names(): - return (', '.join(get_connection_name(c, i) for i, c in enumerate(self._sa_connections)) - or 'There are no more connections!') - - if not self._sa_connections: - self._logger.log("There are no new connections on ServerAuditor's servers.") - self._valediction() - sys.exit(0) - - self._logger.log("The following new hosts have been founded on ServerAuditor's servers:", sleep=0) - self._logger.log(get_connections_names(), color='blue') - - prompt = "You may confirm this list (press 'Enter') or remove host (enter its number): " - while len(self._sa_connections): - number = p_input(prompt).strip() - - if number == '': - break - - try: - number = int(number) - if number >= len(self._sa_connections) or number < 0: - raise IndexError - except (ValueError, IndexError): - self._logger.log("Incorrect index!", color='red', file=sys.stderr) - else: - self._sa_connections.pop(number) - self._logger.log(get_connections_names(), color='blue') - - if not self._sa_connections: - self._valediction() - sys.exit(0) - return - - @description("Creating keys and connections...") - def _create_keys_and_connections(self): - def get_param_name(s): - if any(c.isspace() for c in s): - return '"%s"' % s - return s - - def check_ssh_keys_dir(): - key_dir = os.path.expanduser(self.SSH_KEYS_DIR) - if not os.path.exists(key_dir): - os.mkdir(key_dir) - - def get_key_path(key): - key_name = test_key_name = os.path.join(self.SSH_KEYS_DIR, key['label']) - i = 1 - while os.path.exists(os.path.expanduser(test_key_name)): - test_key_name = key_name + '-%d' % i - i += 1 - return test_key_name - - def create_connection(conn): - - name = conn['label'] or conn['hostname'] - is_duplicate = name in self._local_hosts - if is_duplicate: - self._logger.log(('Seems local config already contains host with name "{name}". ' - 'SSH won\'t be able to use the second one.').format(name=name), color='blue') - - with open(self._config.USER_CONFIG_PATH, 'a') as cf: - if is_duplicate: - cf.write(self.SSH_CONFIG_HOST_DUPLICATE) - - host = self.SSH_CONFIG_HOST_TEMPLATE.format( - host=get_param_name(name), - hostname=get_param_name(conn['hostname']), - user=get_param_name(conn['ssh_username']), - port=conn['port'] - ) - cf.write(host) - - if conn['ssh_key']: - key = self._sa_keys[conn['ssh_key']['id']] - key_name = get_key_path(key) - idf = self.SSH_CONFIG_HOST_IDENTITY_FILE.format(key=get_param_name(key_name)) - cf.write(idf) - - cf.write('\n') - - if conn['ssh_key']: - check_ssh_keys_dir() - key = self._sa_keys[conn['ssh_key']['id']] - key_name = os.path.expanduser(get_key_path(key)) - - if key['private_key']: - with open(key_name, 'w') as private_file: - private_file.write(key['private_key']) - os.chmod(key_name, stat.S_IWUSR | stat.S_IRUSR) - - if key['public_key']: - with open(key_name + '.pub', 'w') as public_file: - public_file.write(key['public_key']) - os.chmod(key_name + '.pub', stat.S_IWUSR | stat.S_IRUSR | stat.S_IRGRP | stat.S_IROTH) - - return - - for conn in self._sa_connections: - create_connection(conn) - - return - - -def main(): - app = ImportSSHConfigApplication(api=API(), ssh_config=SSHConfig(), cryptor=RNCryptor(), logger=PrettyLogger()) - try: - app.run() - except (KeyboardInterrupt, EOFError): - sys.exit(1) - return - - -if __name__ == "__main__": - - main() diff --git a/setup.py b/setup.py index 82350d9..be24d3f 100644 --- a/setup.py +++ b/setup.py @@ -62,8 +62,8 @@ def get_long_description(): 'tags = serverauditor_sshconfig.handlers:TagsCommand', 'login = serverauditor_sshconfig.account.commands:LoginCommand', 'logout = serverauditor_sshconfig.account.commands:LogoutCommand', - 'push = serverauditor_sshconfig.handlers:PushCommand', - 'pull = serverauditor_sshconfig.handlers:PullCommand', + 'push = serverauditor_sshconfig.cloud.commands:PushCommand', + 'pull = serverauditor_sshconfig.cloud.commands:PullCommand', 'info = serverauditor_sshconfig.handlers:InfoCommand', 'connect = serverauditor_sshconfig.handlers:ConnectCommand', ], diff --git a/tests/cloud/__init__.py b/tests/cloud/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/cloud/test_models.py b/tests/cloud/test_models.py new file mode 100644 index 0000000..a103c98 --- /dev/null +++ b/tests/cloud/test_models.py @@ -0,0 +1,40 @@ +import six +from collections import OrderedDict +from mock import patch, Mock +from serverauditor_sshconfig.cloud.models import ( + Host, Group, Tag, SshKey, SshIdentity, SshConfig, Group, Host, PFRule +) + +from serverauditor_sshconfig.core.settings import ApplicationStorage + + +def test_generator(): + model_classes = ( + Host, Group, Tag, + SshKey, SshIdentity, SshConfig, + Group, Host, PFRule + ) + for model_class in model_classes: + instance = model_class() + not_fk_instance = (i for i in instance.fields if i not in instance.mapping) + for i in not_fk_instance: + setattr(instance, i, i) + + yield save, instance + + +@patch('serverauditor_sshconfig.core.settings.PersistentDict') +def save(model, mocked): + + storage = ApplicationStorage('test') + storage.save(model) + + assert isinstance(model.id, six.integer_types) + + for k, v in model.mapping.items(): + setattr(model, k, [] if v.many else None) + stored_models = OrderedDict(((model.id, model),)) + + driver = mocked.return_value + driver.__setitem__.assert_called_with(model.set_name, stored_models) + driver.sync.assert_called_with() diff --git a/tests/core/__init__.py b/tests/core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/integration/host.bats b/tests/integration/host.bats index b3858e7..bc77ae4 100644 --- a/tests/integration/host.bats +++ b/tests/integration/host.bats @@ -11,7 +11,7 @@ } @test "Add general host" { - rm ~/.serverauditor.storage + rm ~/.serverauditor.storage || true run serverauditor host -L test --port 2022 --username root --password password [ "$status" -eq 0 ] ! [ -z $(cat ~/.serverauditor.storage) ] diff --git a/tests/integration/hosts.bats b/tests/integration/hosts.bats index 3bd809d..c754150 100644 --- a/tests/integration/hosts.bats +++ b/tests/integration/hosts.bats @@ -12,7 +12,7 @@ } @test "List hosts in table format" { - rm ~/.serverauditor.storage + rm ~/.serverauditor.storage || true serverauditor host -L test --port 2022 --username root --password password run serverauditor hosts [ "$status" -eq 0 ] diff --git a/tests/integration/login.bats b/tests/integration/login.bats index 4d10f80..2629599 100644 --- a/tests/integration/login.bats +++ b/tests/integration/login.bats @@ -14,6 +14,8 @@ if [ "$Serverauditor_username" == '' ] || [ "$Serverauditor_password" == '' ];then skip fi + rm ~/.serverauditor || true + run serverauditor login --username $Serverauditor_username -p $Serverauditor_password echo $output [ "$status" -eq 0 ] diff --git a/tests/integration/logout.bats b/tests/integration/logout.bats index f6de1d1..843ae2f 100644 --- a/tests/integration/logout.bats +++ b/tests/integration/logout.bats @@ -14,7 +14,10 @@ if [ "$Serverauditor_username" == '' ] || [ "$Serverauditor_password" == '' ];then skip fi + + rm ~/.serverauditor || true serverauditor login --username $Serverauditor_username -p$Serverauditor_password + run serverauditor logout [ "$status" -eq 0 ] [ -z $(cat ~/.serverauditor) ] diff --git a/tox.ini b/tox.ini new file mode 100644 index 0000000..8693c4e --- /dev/null +++ b/tox.ini @@ -0,0 +1,23 @@ +# Tox (http://tox.testrun.org/) is a tool for running tests +# in multiple virtualenvs. This configuration file will run the +# test suite on all supported python versions. To use it, "pip install tox" +# and then run "tox" from this directory. + +[tox] +envlist = py27,py30,py34 +skipsdist = True + +[testenv] +deps = + mock + nose + coverage + prospector +commands = + pip install -U . + nosetests -c .noserc + prospector + +[flake8] +exclude = .git* +filename = *.py \ No newline at end of file