163 lines
6.7 KiB
Python
163 lines
6.7 KiB
Python
# Licensed to the Software Freedom Conservancy (SFC) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The SFC licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
|
|
import base64
|
|
import os
|
|
import socket
|
|
from enum import Enum
|
|
from typing import Optional
|
|
from urllib import parse
|
|
|
|
import certifi
|
|
|
|
from selenium.webdriver.common.proxy import Proxy
|
|
from selenium.webdriver.common.proxy import ProxyType
|
|
|
|
|
|
class AuthType(Enum):
    """Authentication schemes that can be selected for the remote connection.

    Each member's value is the literal scheme / header name string that is
    placed into the outgoing request headers.
    """

    # Username/password credentials.
    BASIC = "Basic"

    # Token-based authentication.
    BEARER = "Bearer"

    # API key authentication; the value doubles as the header name.
    X_API_KEY = "X-API-Key"
|
|
|
|
|
|
class _ClientConfigDescriptor:
    """Data descriptor that keeps a value in the owning instance's ``__dict__``.

    The value is stored under ``name`` (the key passed to the constructor), so
    an attribute declared as ``x = _ClientConfigDescriptor("_x")`` reads and
    writes ``instance.__dict__["_x"]``.
    """

    def __init__(self, name: str) -> None:
        """Remember the ``__dict__`` key used to store the value.

        Args:
            name: Key under which the value is stored on each instance.
        """
        self.name = name

    def __get__(self, obj, cls=None):
        """Return the stored value, or the descriptor itself on class access.

        Raises:
            KeyError: If the attribute is read on an instance before it has
                been assigned.
        """
        # Class-level access (e.g. ``ClientConfig.timeout``, help(), autodoc)
        # passes obj=None; return the descriptor instead of failing on
        # ``None.__dict__``.
        if obj is None:
            return self
        return obj.__dict__[self.name]

    def __set__(self, obj, value) -> None:
        """Store ``value`` on ``obj`` under the configured key."""
        obj.__dict__[self.name] = value
|
|
|
|
|
|
class ClientConfig:
    """Settings used when communicating with the remote driver/server.

    Each public attribute below is a ``_ClientConfigDescriptor``; its value is
    stored on the instance under the underscore-prefixed name (for example
    ``timeout`` is kept in ``_timeout``).
    """

    remote_server_addr = _ClientConfigDescriptor("_remote_server_addr")
    """Gets and Sets Remote Server."""
    keep_alive = _ClientConfigDescriptor("_keep_alive")
    """Gets and Sets Keep Alive value."""
    proxy = _ClientConfigDescriptor("_proxy")
    """Gets and Sets the proxy used for communicating to the driver/server."""
    ignore_certificates = _ClientConfigDescriptor("_ignore_certificates")
    """Gets and Sets the ignore certificate check value."""
    init_args_for_pool_manager = _ClientConfigDescriptor("_init_args_for_pool_manager")
    """Gets and Sets the init arguments for the pool manager."""
    timeout = _ClientConfigDescriptor("_timeout")
    """Gets and Sets the timeout (in seconds) used for communicating to the
    driver/server."""
    ca_certs = _ClientConfigDescriptor("_ca_certs")
    """Gets and Sets the path to bundle of CA certificates."""
    username = _ClientConfigDescriptor("_username")
    """Gets and Sets the username used for basic authentication to the
    remote."""
    password = _ClientConfigDescriptor("_password")
    """Gets and Sets the password used for basic authentication to the
    remote."""
    auth_type = _ClientConfigDescriptor("_auth_type")
    """Gets and Sets the type of authentication to the remote server."""
    token = _ClientConfigDescriptor("_token")
    """Gets and Sets the token used for authentication to the remote server."""
    user_agent = _ClientConfigDescriptor("_user_agent")
    """Gets and Sets user agent to be added to the request headers."""
    extra_headers = _ClientConfigDescriptor("_extra_headers")
    """Gets and Sets extra headers to be added to the request."""

    def __init__(
        self,
        remote_server_addr: str,
        keep_alive: Optional[bool] = True,
        # NOTE(review): this default Proxy is built once, when the function is
        # defined, and is therefore shared by every ClientConfig that does not
        # pass its own proxy. Kept as-is to leave the signature unchanged.
        proxy: Optional[Proxy] = Proxy(raw={"proxyType": ProxyType.SYSTEM}),
        ignore_certificates: Optional[bool] = False,
        init_args_for_pool_manager: Optional[dict] = None,
        timeout: Optional[int] = None,
        ca_certs: Optional[str] = None,
        username: Optional[str] = None,
        password: Optional[str] = None,
        auth_type: Optional[AuthType] = AuthType.BASIC,
        token: Optional[str] = None,
        user_agent: Optional[str] = None,
        extra_headers: Optional[dict] = None,
    ) -> None:
        """Store the connection settings, resolving environment defaults.

        Args:
            remote_server_addr: Address of the remote server; its scheme and
                network location are used when picking a proxy URL.
            keep_alive: Keep Alive value.
            proxy: Proxy used for communicating to the driver/server.
            ignore_certificates: Ignore certificate check value.
            init_args_for_pool_manager: Init arguments for the pool manager;
                a falsy value is replaced by an empty dict.
            timeout: Timeout in seconds. When ``None``, the
                ``GLOBAL_DEFAULT_TIMEOUT`` environment variable (parsed as a
                float) is used if set, otherwise ``socket.getdefaulttimeout()``.
            ca_certs: Path to a bundle of CA certificates. When ``None``, the
                ``REQUESTS_CA_BUNDLE`` environment variable is used if set,
                otherwise ``certifi.where()``.
            username: Username for basic authentication.
            password: Password for basic authentication.
            auth_type: Type of authentication to the remote server.
            token: Token for bearer / API-key authentication.
            user_agent: User agent to be added to the request headers.
            extra_headers: Extra headers to be added to the request.

        Raises:
            ValueError: If ``timeout`` is ``None`` and ``GLOBAL_DEFAULT_TIMEOUT``
                is set to something ``float()`` cannot parse.
        """
        self.remote_server_addr = remote_server_addr
        self.keep_alive = keep_alive
        self.proxy = proxy
        self.ignore_certificates = ignore_certificates
        self.init_args_for_pool_manager = init_args_for_pool_manager or {}
        self.username = username
        self.password = password
        self.auth_type = auth_type
        self.token = token
        self.user_agent = user_agent
        self.extra_headers = extra_headers

        # An explicit timeout wins; otherwise environment, then socket default.
        if timeout is None:
            env_timeout = os.getenv("GLOBAL_DEFAULT_TIMEOUT")
            timeout = socket.getdefaulttimeout() if env_timeout is None else float(env_timeout)
        self.timeout = timeout

        # An explicit bundle wins; otherwise environment, then certifi's bundle.
        if ca_certs is None:
            env_bundle = os.environ.get("REQUESTS_CA_BUNDLE")
            ca_certs = certifi.where() if env_bundle is None else env_bundle
        self.ca_certs = ca_certs

    def reset_timeout(self) -> None:
        """Resets the timeout to the default value of socket.

        Only ``socket.getdefaulttimeout()`` is consulted here; the
        ``GLOBAL_DEFAULT_TIMEOUT`` environment variable is not re-read.
        """
        self.timeout = socket.getdefaulttimeout()

    def get_proxy_url(self) -> Optional[str]:
        """Returns the proxy URL to use for the connection.

        Returns:
            ``None`` for a direct connection, when a ``no_proxy`` / ``NO_PROXY``
            entry matches the remote address, when no proxy is configured, or
            for any proxy type other than SYSTEM and MANUAL. Otherwise the
            proxy URL taken from the environment (SYSTEM) or from the
            configured proxy object (MANUAL), chosen by whether the remote
            address starts with ``https://``.
        """
        proxy_type = self.proxy.proxy_type
        remote_netloc = parse.urlparse(self.remote_server_addr).netloc

        if proxy_type is ProxyType.DIRECT:
            return None

        if proxy_type is ProxyType.SYSTEM:
            # Lower-case variable takes precedence over the upper-case one.
            no_proxy = os.environ.get("no_proxy", os.environ.get("NO_PROXY"))
            if no_proxy:
                for entry in map(str.strip, no_proxy.split(",")):
                    if entry == "*":
                        return None
                    parsed = parse.urlparse(entry)
                    # Entry written as a full URL: compare network locations.
                    if parsed.netloc and parsed.netloc == remote_netloc:
                        return None
                    # Entry written as a bare host: substring match. The
                    # non-empty check matters because "" is contained in every
                    # string, so an empty entry (trailing comma) or a URL entry
                    # for another host would otherwise bypass the proxy always.
                    if parsed.path and parsed.path in remote_netloc:
                        return None
            if self.remote_server_addr.startswith("https://"):
                return os.environ.get("https_proxy", os.environ.get("HTTPS_PROXY"))
            return os.environ.get("http_proxy", os.environ.get("HTTP_PROXY"))

        if proxy_type is ProxyType.MANUAL:
            if self.remote_server_addr.startswith("https://"):
                return self.proxy.sslProxy
            return self.proxy.http_proxy

        return None

    def get_auth_header(self) -> Optional[dict]:
        """Returns the authorization to add to the request headers.

        Returns:
            A single-entry dict holding the header for the configured
            ``auth_type``, or ``None`` when the credentials that type needs
            (username and password, or token) are missing or empty.
        """
        if self.auth_type is AuthType.BASIC and self.username and self.password:
            raw_credentials = f"{self.username}:{self.password}".encode("utf-8")
            encoded_credentials = base64.b64encode(raw_credentials).decode("utf-8")
            return {"Authorization": f"{AuthType.BASIC.value} {encoded_credentials}"}

        if self.auth_type is AuthType.BEARER and self.token:
            return {"Authorization": f"{AuthType.BEARER.value} {self.token}"}

        if self.auth_type is AuthType.X_API_KEY and self.token:
            return {AuthType.X_API_KEY.value: f"{self.token}"}

        return None
|