aboutsummaryrefslogtreecommitdiffstats
path: root/Demos/Device/ClassDriver/KeyboardMouse/KeyboardMouse.c
blob: 6ba7ce3d0c12ef0e77582aec1155eecf375dfb88 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
/*
             LUFA Library
     Copyright (C) Dean Camera, 2010.
              
  dean [at] fourwalledcubicle [dot] com
      www.fourwalledcubicle.com
*/

/*
  Copyright 2010  Dean Camera (dean [at] fourwalledcubicle [dot] com)
	  
  Permission to use, copy, modify, distribute, and sell this 
  software and its documentation for any purpose is hereby granted
  without fee, provided that the above copyright notice appear in 
  all copies and that both that the copyright notice and this
  permission notice and warranty disclaimer appear in supporting 
  documentation, and that the name of the author not be used in 
  advertising or publicity pertaining to distribution of the 
  software without specific, written prior permission.

  The author disclaim all warranties with regard to this
  software, including all implied warranties of merchantability
  and fitness.  In no event shall the author be liable for any
  special, indirect or consequential damages or any damages
  whatsoever resulting from loss of use, data or profits, whether
  in an action of contract, negligence or other tortious action,
  arising out of or in connection with the use or performance of
  this software.
*/

/** \file
 *
 *  Main source file for the KeyboardMouse demo. This file contains the main tasks of
 *  the demo and is responsible for the initial application hardware configuration.
 */

#include "KeyboardMouse.h"

/** Buffer to hold the previously generated Keyboard HID report, for comparison purposes inside the HID class driver.
 *  It is sized to hold exactly one \c USB_KeyboardReport_Data_t and is handed to the driver through the
 *  \c PrevReportINBuffer / \c PrevReportINBufferSize fields of \ref Keyboard_HID_Interface.
 */
uint8_t PrevKeyboardHIDReportBuffer[sizeof(USB_KeyboardReport_Data_t)];

/** Buffer to hold the previously generated Mouse HID report, for comparison purposes inside the HID class driver.
 *  It is sized to hold exactly one \c USB_MouseReport_Data_t and is handed to the driver through the
 *  \c PrevReportINBuffer / \c PrevReportINBufferSize fields of \ref Mouse_HID_Interface.
 */
uint8_t PrevMouseHIDReportBuffer[sizeof(USB_MouseReport_Data_t)];

/** Configuration and state of the keyboard HID interface, as consumed by the LUFA HID Class driver.
 *  A pointer to this structure is passed to every HID Class driver function so that the driver can
 *  tell this interface apart from the other HID interface exposed by the device.
 */
USB_ClassInfo_HID_Device_t Keyboard_HID_Interface = {
	.Config = {
		.InterfaceNumber            = 0,

		/* IN endpoint used to send keyboard reports */
		.ReportINEndpointNumber     = KEYBOARD_IN_EPNUM,
		.ReportINEndpointSize       = HID_EPSIZE,
		.ReportINEndpointDoubleBank = false,

		/* Storage for the previously generated keyboard report */
		.PrevReportINBuffer         = PrevKeyboardHIDReportBuffer,
		.PrevReportINBufferSize     = sizeof(PrevKeyboardHIDReportBuffer),
	},
};
	
/** LUFA HID Class driver interface configuration and state information. This structure is
 *  passed to all HID Class driver functions, so that multiple instances of the same class
 *  within a device can be differentiated from one another. This is for the mouse HID
 *  interface within the device.
 */
USB_ClassInfo_HID_Device_t Mouse_HID_Interface =
	{
		.Config =
			{
				.InterfaceNumber              = 1,

				.ReportINEndpointNumber       = MOUSE_IN_EPNUM,
				.ReportINEndpointSize         = HID_EPSIZE,
				/* Spelled out to mirror the keyboard interface; an omitted member is zero-initialised,
				 * so this is the same value the structure already had. */
				.ReportINEndpointDoubleBank   = false,

				.PrevReportINBuffer           = PrevMouseHIDReportBuffer,
				.PrevReportINBufferSize       = sizeof(PrevMouseHIDReportBuffer),
			},
	};

/** Program entry point. Performs the one-time hardware setup, sets the board LEDs to
 *  LEDMASK_USB_NOTREADY, calls sei() and then services the USB tasks in an endless loop.
 */
int main(void)
{
	SetupHardware();

	LEDs_SetAllLEDs(LEDMASK_USB_NOTREADY);
	sei();

	/* Service both HID interfaces, then the core USB task, on every pass; this loop never exits */
	while (true)
	{
		HID_Device_USBTask(&Keyboard_HID_Interface);
		HID_Device_USBTask(&Mouse_HID_Interface);
		USB_USBTask();
	}
}

/** Configures the board hardware and chip peripherals for the demo's functionality.
 *
 *  Disables the watchdog, removes the clock prescaler, and initialises every board driver the demo
 *  reads from or writes to (joystick, LEDs, buttons) before starting the USB interface.
 */
void SetupHardware(void)
{
	/* Disable watchdog if enabled by bootloader/fuses */
	MCUSR &= ~(1 << WDRF);
	wdt_disable();

	/* Disable clock division */
	clock_prescale_set(clock_div_1);

	/* Hardware Initialization */
	Joystick_Init();
	LEDs_Init();
	/* The report callback polls Buttons_GetStatus(), so the button driver must be initialised too */
	Buttons_Init();
	USB_Init();
}

/** Handler for the library USB Connection event: sets the board LEDs to LEDMASK_USB_ENUMERATING. */
void EVENT_USB_Device_Connect(void)
{
	LEDs_SetAllLEDs(LEDMASK_USB_ENUMERATING);
}

/** Handler for the library USB Disconnection event: sets the board LEDs back to LEDMASK_USB_NOTREADY. */
void EVENT_USB_Device_Disconnect(void)
{
	LEDs_SetAllLEDs(LEDMASK_USB_NOTREADY);
}

/** Handler for the library USB Configuration Changed event. Configures the endpoints of both HID
 *  interfaces, enables Start Of Frame events and shows the overall outcome on the board LEDs.
 */
void EVENT_USB_Device_ConfigurationChanged(void)
{
	/* Both interfaces are always configured, even if the first one fails */
	const bool KeyboardConfigured = HID_Device_ConfigureEndpoints(&Keyboard_HID_Interface);
	const bool MouseConfigured    = HID_Device_ConfigureEndpoints(&Mouse_HID_Interface);

	USB_Device_EnableSOFEvents();

	if (KeyboardConfigured && MouseConfigured)
	{
		LEDs_SetAllLEDs(LEDMASK_USB_READY);
	}
	else
	{
		LEDs_SetAllLEDs(LEDMASK_USB_ERROR);
	}
}

/** Event handler for the library USB Unhandled Control Request event. */
void EVENT_USB_Device_UnhandledControlRequest(void)
{
	HID_Device_ProcessControlRequest(&Keyboard_HID_Interface);
	HID_Device_ProcessControlRequest(&Mouse_HID_Interface);
}

/** Event handler for the USB device Start Of Frame event. */
void EVENT_USB_Device_StartOfFrame(void)
{
	HID_Device_MillisecondElapsed(&Keyboard_HID_Interface);
	HID_Device_MillisecondElapsed(&Mouse_HID_Interface);
}

/** HID class driver callback function for the creation of HID reports to the host.
 *
 *  The first board button selects which interface is active. While it is held, the keyboard interface
 *  reports a left-shifted key chosen by the joystick and the mouse interface reports nothing. While it
 *  is released, the keyboard interface reports nothing and the joystick drives the mouse X/Y movement
 *  and first mouse button instead. Any interface other than the keyboard one is treated as the mouse.
 *
 *  Only the report fields affected by the current joystick state are written.
 *  NOTE(review): this assumes the caller hands in a zeroed report buffer - confirm against the HID class driver.
 *
 *  \param[in]     HIDInterfaceInfo  Pointer to the HID class interface configuration structure being referenced
 *  \param[in,out] ReportID  Report ID requested by the host if non-zero, otherwise callback should set to the generated report ID
 *                           (not used by this demo)
 *  \param[in]     ReportType  Type of the report to create, either REPORT_ITEM_TYPE_In or REPORT_ITEM_TYPE_Feature
 *                             (not used by this demo)
 *  \param[out]    ReportData  Pointer to a buffer where the created report should be stored
 *  \param[out]    ReportSize  Number of bytes written in the report (or zero if no report is to be sent)
 *
 *  \return Boolean true to force the sending of the report, false to let the library determine if it needs to be sent
 */
bool CALLBACK_HID_Device_CreateHIDReport(USB_ClassInfo_HID_Device_t* const HIDInterfaceInfo,
                                         uint8_t* const ReportID,
                                         const uint8_t ReportType,
                                         void* ReportData,
                                         uint16_t* const ReportSize)
{
	uint8_t JoyStatus_LCL    = Joystick_GetStatus();
	uint8_t ButtonStatus_LCL = Buttons_GetStatus();

	/* Determine which interface must have its report generated */
	if (HIDInterfaceInfo == &Keyboard_HID_Interface)
	{
		USB_KeyboardReport_Data_t* KeyboardReport = (USB_KeyboardReport_Data_t*)ReportData;

		/* If first board button not being held down, no keyboard report */
		if (!(ButtonStatus_LCL & BUTTONS_BUTTON1))
		{
			/* Report an explicit zero length instead of relying on the caller's initial value */
			*ReportSize = 0;
			return false;
		}

		KeyboardReport->Modifier = HID_KEYBOARD_MODIFER_LEFTSHIFT;

		/* Only one key slot is used, so a later match below overrides an earlier one */
		if (JoyStatus_LCL & JOY_UP)
		  KeyboardReport->KeyCode[0] = 0x04; // A
		else if (JoyStatus_LCL & JOY_DOWN)
		  KeyboardReport->KeyCode[0] = 0x05; // B

		if (JoyStatus_LCL & JOY_LEFT)
		  KeyboardReport->KeyCode[0] = 0x06; // C
		else if (JoyStatus_LCL & JOY_RIGHT)
		  KeyboardReport->KeyCode[0] = 0x07; // D

		if (JoyStatus_LCL & JOY_PRESS)
		  KeyboardReport->KeyCode[0] = 0x08; // E

		*ReportSize = sizeof(USB_KeyboardReport_Data_t);

		/* Let the library decide whether the keyboard report needs to be sent */
		return false;
	}
	else
	{
		USB_MouseReport_Data_t* MouseReport = (USB_MouseReport_Data_t*)ReportData;

		/* If first board button being held down, no mouse report */
		if (ButtonStatus_LCL & BUTTONS_BUTTON1)
		{
			/* Report an explicit zero length instead of relying on the caller's initial value */
			*ReportSize = 0;
			return false;
		}

		if (JoyStatus_LCL & JOY_UP)
		  MouseReport->Y = -1;
		else if (JoyStatus_LCL & JOY_DOWN)
		  MouseReport->Y =  1;

		if (JoyStatus_LCL & JOY_LEFT)
		  MouseReport->X = -1;
		else if (JoyStatus_LCL & JOY_RIGHT)
		  MouseReport->X =  1;

		if (JoyStatus_LCL & JOY_PRESS)
		  MouseReport->Button |= (1 << 0);

		*ReportSize = sizeof(USB_MouseReport_Data_t);

		/* Force the mouse report to be sent every time */
		return true;
	}
}

/** HID class driver callback function for the processing of HID reports from the host.
 *
 *  Only reports addressed to the keyboard interface are acted upon: the first byte of the report is
 *  read as the keyboard LED bitmask and mirrored onto the board LEDs (Num Lock to LED1, Caps Lock to
 *  LED3, Scroll Lock to LED4). Reports for any other interface, and empty reports, are ignored.
 *
 *  \param[in] HIDInterfaceInfo  Pointer to the HID class interface configuration structure being referenced
 *  \param[in] ReportID    Report ID of the received report from the host (not used by this demo)
 *  \param[in] ReportType  The type of report that the host has sent, either REPORT_ITEM_TYPE_Out or REPORT_ITEM_TYPE_Feature
 *                         (not used by this demo)
 *  \param[in] ReportData  Pointer to a buffer where the received report has been stored
 *  \param[in] ReportSize  Size in bytes of the received HID report
 */
void CALLBACK_HID_Device_ProcessHIDReport(USB_ClassInfo_HID_Device_t* const HIDInterfaceInfo,
                                          const uint8_t ReportID,
                                          const uint8_t ReportType,
                                          const void* ReportData,
                                          const uint16_t ReportSize)
{
	if (HIDInterfaceInfo == &Keyboard_HID_Interface)
	{
		/* An empty report carries no LED byte; do not read past the end of the buffer */
		if (ReportSize < 1)
		  return;

		uint8_t        LEDMask   = LEDS_NO_LEDS;
		const uint8_t* LEDReport = ReportData;

		if (*LEDReport & HID_KEYBOARD_LED_NUMLOCK)
		  LEDMask |= LEDS_LED1;

		if (*LEDReport & HID_KEYBOARD_LED_CAPSLOCK)
		  LEDMask |= LEDS_LED3;

		if (*LEDReport & HID_KEYBOARD_LED_SCROLLLOCK)
		  LEDMask |= LEDS_LED4;

		LEDs_SetAllLEDs(LEDMask);
	}
}
span class="n">getenv("SSLKEYLOGFILE")) class _FileLike(object): BLOCKSIZE = 1024 * 32 def __init__(self, o): self.o = o self._log = None self.first_byte_timestamp = None def set_descriptor(self, o): self.o = o def __getattr__(self, attr): return getattr(self.o, attr) def start_log(self): """ Starts or resets the log. This will store all bytes read or written. """ self._log = [] def stop_log(self): """ Stops the log. """ self._log = None def is_logging(self): return self._log is not None def get_log(self): """ Returns the log as a string. """ if not self.is_logging(): raise ValueError("Not logging!") return "".join(self._log) def add_log(self, v): if self.is_logging(): self._log.append(v) def reset_timestamps(self): self.first_byte_timestamp = None class Writer(_FileLike): def flush(self): """ May raise NetLibDisconnect """ if hasattr(self.o, "flush"): try: self.o.flush() except (socket.error, IOError) as v: raise NetLibDisconnect(str(v)) def write(self, v): """ May raise NetLibDisconnect """ if v: self.first_byte_timestamp = self.first_byte_timestamp or time.time() try: if hasattr(self.o, "sendall"): self.add_log(v) return self.o.sendall(v) else: r = self.o.write(v) self.add_log(v[:r]) return r except (SSL.Error, socket.error) as e: raise NetLibDisconnect(str(e)) class Reader(_FileLike): def read(self, length): """ If length is -1, we read until connection closes. 
""" result = '' start = time.time() while length == -1 or length > 0: if length == -1 or length > self.BLOCKSIZE: rlen = self.BLOCKSIZE else: rlen = length try: data = self.o.read(rlen) except SSL.ZeroReturnError: break except SSL.WantReadError: if (time.time() - start) < self.o.gettimeout(): time.sleep(0.1) continue else: raise NetLibTimeout except socket.timeout: raise NetLibTimeout except socket.error: raise NetLibDisconnect except SSL.SysCallError as e: if e.args == (-1, 'Unexpected EOF'): break raise NetLibSSLError(e.message) except SSL.Error as e: raise NetLibSSLError(e.message) self.first_byte_timestamp = self.first_byte_timestamp or time.time() if not data: break result += data if length != -1: length -= len(data) self.add_log(result) return result def readline(self, size=None): result = '' bytes_read = 0 while True: if size is not None and bytes_read >= size: break ch = self.read(1) bytes_read += 1 if not ch: break else: result += ch if ch == '\n': break return result def safe_read(self, length): """ Like .read, but is guaranteed to either return length bytes, or raise an exception. """ result = self.read(length) if length != -1 and len(result) != length: if not result: raise NetLibDisconnect() else: raise NetLibIncomplete( "Expected %s bytes, got %s" % (length, len(result)) ) return result class Address(object): """ This class wraps an IPv4/IPv6 tuple to provide named attributes and ipv6 information. 
""" def __init__(self, address, use_ipv6=False): self.address = tuple(address) self.use_ipv6 = use_ipv6 @classmethod def wrap(cls, t): if isinstance(t, cls): return t else: return cls(t) def __call__(self): return self.address @property def host(self): return self.address[0] @property def port(self): return self.address[1] @property def use_ipv6(self): return self.family == socket.AF_INET6 @use_ipv6.setter def use_ipv6(self, b): self.family = socket.AF_INET6 if b else socket.AF_INET def __repr__(self): return "{}:{}".format(self.host, self.port) def __str__(self): return str(self.address) def __eq__(self, other): other = Address.wrap(other) return (self.address, self.family) == (other.address, other.family) def __ne__(self, other): return not self.__eq__(other) def close_socket(sock): """ Does a hard close of a socket, without emitting a RST. """ try: # We already indicate that we close our end. # may raise "Transport endpoint is not connected" on Linux sock.shutdown(socket.SHUT_WR) # Section 4.2.2.13 of RFC 1122 tells us that a close() with any pending # readable data could lead to an immediate RST being sent (which is the # case on Windows). # http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html # # This in turn results in the following issue: If we send an error page # to the client and then close the socket, the RST may be received by # the client before the error page and the users sees a connection # error rather than the error page. Thus, we try to empty the read # buffer on Windows first. (see # https://github.com/mitmproxy/mitmproxy/issues/527#issuecomment-93782988) # if os.name == "nt": # pragma: no cover # We cannot rely on the shutdown()-followed-by-read()-eof technique # proposed by the page above: Some remote machines just don't send # a TCP FIN, which would leave us in the unfortunate situation that # recv() would block infinitely. 
As a workaround, we set a timeout # here even if we are in blocking mode. sock.settimeout(sock.gettimeout() or 20) # limit at a megabyte so that we don't read infinitely for _ in xrange(1024 ** 3 // 4096): # may raise a timeout/disconnect exception. if not sock.recv(4096): break # Now we can close the other half as well. sock.shutdown(socket.SHUT_RD) except socket.error: pass sock.close() class _Connection(object): def get_current_cipher(self): if not self.ssl_established: return None name = self.connection.get_cipher_name() bits = self.connection.get_cipher_bits() version = self.connection.get_cipher_version() return name, bits, version def finish(self): self.finished = True # If we have an SSL connection, wfile.close == connection.close # (We call _FileLike.set_descriptor(conn)) # Closing the socket is not our task, therefore we don't call close # then. if not isinstance(self.connection, SSL.Connection): if not getattr(self.wfile, "closed", False): try: self.wfile.flush() self.wfile.close() except NetLibDisconnect: pass self.rfile.close() else: try: self.connection.shutdown() except SSL.Error: pass """ Creates an SSL Context. 
""" def _create_ssl_context(self, method=SSLv23_METHOD, options=(OP_NO_SSLv2 | OP_NO_SSLv3), verify_options=VERIFY_NONE, ca_path=None, ca_pemfile=None, cipher_list=None, alpn_protos=None, alpn_select=None, ): """ :param method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, TLSv1_METHOD, TLSv1_1_METHOD, or TLSv1_2_METHOD :param options: A bit field consisting of OpenSSL.SSL.OP_* values :param verify_options: A bit field consisting of OpenSSL.SSL.VERIFY_* values :param ca_path: Path to a directory of trusted CA certificates prepared using the c_rehash tool :param ca_pemfile: Path to a PEM formatted trusted CA certificate :param cipher_list: A textual OpenSSL cipher list, see https://www.openssl.org/docs/apps/ciphers.html :rtype : SSL.Context """ context = SSL.Context(method) # Options (NO_SSLv2/3) if options is not None: context.set_options(options) # Verify Options (NONE/PEER/PEER|FAIL_IF_... and trusted CAs) if verify_options is not None and verify_options is not VERIFY_NONE: def verify_cert(conn, cert, errno, err_depth, is_cert_verified): if is_cert_verified: return True raise NetLibError( "Upstream certificate validation failed at depth: %s with error number: %s" % (err_depth, errno)) context.set_verify(verify_options, verify_cert) if ca_path is not None or ca_pemfile is not None: context.load_verify_locations(ca_pemfile, ca_path) # Workaround for # https://github.com/pyca/pyopenssl/issues/190 # https://github.com/mitmproxy/mitmproxy/issues/472 # Options already set before are not cleared. 
context.set_mode(SSL._lib.SSL_MODE_AUTO_RETRY) # Cipher List if cipher_list: try: context.set_cipher_list(cipher_list) except SSL.Error as v: raise NetLibError("SSL cipher specification error: %s" % str(v)) # SSLKEYLOGFILE if log_ssl_key: context.set_info_callback(log_ssl_key) if OpenSSL._util.lib.Cryptography_HAS_ALPN: if alpn_protos is not None: # advertise application layer protocols context.set_alpn_protos(alpn_protos) elif alpn_select is not None: # select application layer protocol def alpn_select_callback(conn, options): if alpn_select in options: return bytes(alpn_select) else: # pragma no cover return options[0] context.set_alpn_select_callback(alpn_select_callback) return context class TCPClient(_Connection): rbufsize = -1 wbufsize = -1 def close(self): # Make sure to close the real socket, not the SSL proxy. # OpenSSL is really good at screwing up, i.e. when trying to recv from a failed connection, # it tries to renegotiate... if isinstance(self.connection, SSL.Connection): close_socket(self.connection._socket) else: close_socket(self.connection) def __init__(self, address, source_address=None): self.address = Address.wrap(address) self.source_address = Address.wrap( source_address) if source_address else None self.connection, self.rfile, self.wfile = None, None, None self.cert = None self.ssl_established = False self.sni = None def create_ssl_context(self, cert=None, alpn_protos=None, **sslctx_kwargs): context = self._create_ssl_context( alpn_protos=alpn_protos, **sslctx_kwargs) # Client Certs if cert: try: context.use_privatekey_file(cert) context.use_certificate_file(cert) except SSL.Error as v: raise NetLibError("SSL client certificate error: %s" % str(v)) return context def convert_to_ssl(self, sni=None, alpn_protos=None, **sslctx_kwargs): """ cert: Path to a file containing both client cert and private key. 
options: A bit field consisting of OpenSSL.SSL.OP_* values verify_options: A bit field consisting of OpenSSL.SSL.VERIFY_* values ca_path: Path to a directory of trusted CA certificates prepared using the c_rehash tool ca_pemfile: Path to a PEM formatted trusted CA certificate """ context = self.create_ssl_context( alpn_protos=alpn_protos, **sslctx_kwargs) self.connection = SSL.Connection(context, self.connection) if sni: self.sni = sni self.connection.set_tlsext_host_name(sni) self.connection.set_connect_state() try: self.connection.do_handshake() except SSL.Error as v: raise NetLibError("SSL handshake error: %s" % repr(v)) self.ssl_established = True self.cert = certutils.SSLCert(self.connection.get_peer_certificate()) self.rfile.set_descriptor(self.connection) self.wfile.set_descriptor(self.connection) def connect(self): try: connection = socket.socket(self.address.family, socket.SOCK_STREAM) if self.source_address: connection.bind(self.source_address()) connection.connect(self.address()) if not self.source_address: self.source_address = Address(connection.getsockname()) self.rfile = Reader(connection.makefile('rb', self.rbufsize)) self.wfile = Writer(connection.makefile('wb', self.wbufsize)) except (socket.error, IOError) as err: raise NetLibError( 'Error connecting to "%s": %s' % (self.address.host, err)) self.connection = connection def settimeout(self, n): self.connection.settimeout(n) def gettimeout(self): return self.connection.gettimeout() def get_alpn_proto_negotiated(self): if OpenSSL._util.lib.Cryptography_HAS_ALPN and self.ssl_established: return self.connection.get_alpn_proto_negotiated() else: return None class BaseHandler(_Connection): """ The instantiator is expected to call the handle() and finish() methods. 
""" rbufsize = -1 wbufsize = -1 def __init__(self, connection, address, server): self.connection = connection self.address = Address.wrap(address) self.server = server self.rfile = Reader(self.connection.makefile('rb', self.rbufsize)) self.wfile = Writer(self.connection.makefile('wb', self.wbufsize)) self.finished = False self.ssl_established = False self.clientcert = None def create_ssl_context(self, cert, key, handle_sni=None, request_client_cert=None, chain_file=None, dhparams=None, **sslctx_kwargs): """ cert: A certutils.SSLCert object. handle_sni: SNI handler, should take a connection object. Server name can be retrieved like this: connection.get_servername() And you can specify the connection keys as follows: new_context = Context(TLSv1_METHOD) new_context.use_privatekey(key) new_context.use_certificate(cert) connection.set_context(new_context) The request_client_cert argument requires some explanation. We're supposed to be able to do this with no negative effects - if the client has no cert to present, we're notified and proceed as usual. Unfortunately, Android seems to have a bug (tested on 4.2.2) - when an Android client is asked to present a certificate it does not have, it hangs up, which is frankly bogus. Some time down the track we may be able to make the proper behaviour the default again, but until then we're conservative. 
""" context = self._create_ssl_context(**sslctx_kwargs) context.use_privatekey(key) context.use_certificate(cert.x509) if handle_sni: # SNI callback happens during do_handshake() context.set_tlsext_servername_callback(handle_sni) if request_client_cert: def save_cert(conn, cert, errno, depth, preverify_ok): self.clientcert = certutils.SSLCert(cert) # Return true to prevent cert verification error return True context.set_verify(SSL.VERIFY_PEER, save_cert) # Cert Verify if chain_file: context.load_verify_locations(chain_file) if dhparams: SSL._lib.SSL_CTX_set_tmp_dh(context._context, dhparams) return context def convert_to_ssl(self, cert, key, **sslctx_kwargs): """ Convert connection to SSL. For a list of parameters, see BaseHandler._create_ssl_context(...) """ context = self.create_ssl_context( cert, key, **sslctx_kwargs) self.connection = SSL.Connection(context, self.connection) self.connection.set_accept_state() try: self.connection.do_handshake() except SSL.Error as v: raise NetLibError("SSL handshake error: %s" % repr(v)) self.ssl_established = True self.rfile.set_descriptor(self.connection) self.wfile.set_descriptor(self.connection) def handle(self): # pragma: no cover raise NotImplementedError def settimeout(self, n): self.connection.settimeout(n) def get_alpn_proto_negotiated(self): if OpenSSL._util.lib.Cryptography_HAS_ALPN and self.ssl_established: return self.connection.get_alpn_proto_negotiated() else: return None class TCPServer(object): request_queue_size = 20 def __init__(self, address): self.address = Address.wrap(address) self.__is_shut_down = threading.Event() self.__shutdown_request = False self.socket = socket.socket(self.address.family, socket.SOCK_STREAM) self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) self.socket.bind(self.address()) self.address = Address.wrap(self.socket.getsockname()) self.socket.listen(self.request_queue_size) def connection_thread(self, connection, client_address): client_address = Address(client_address) 
try: self.handle_client_connection(connection, client_address) except: self.handle_error(connection, client_address) finally: close_socket(connection) def serve_forever(self, poll_interval=0.1): self.__is_shut_down.clear() try: while not self.__shutdown_request: try: r, w, e = select.select( [self.socket], [], [], poll_interval) except select.error as ex: # pragma: no cover if ex[0] == EINTR: continue else: raise if self.socket in r: connection, client_address = self.socket.accept() t = threading.Thread( target=self.connection_thread, args=(connection, client_address), name="ConnectionThread (%s:%s -> %s:%s)" % (client_address[0], client_address[1], self.address.host, self.address.port) ) t.setDaemon(1) try: t.start() except threading.ThreadError: self.handle_error(connection, Address(client_address)) connection.close() finally: self.__shutdown_request = False self.__is_shut_down.set() def shutdown(self): self.__shutdown_request = True self.__is_shut_down.wait() self.socket.close() self.handle_shutdown() def handle_error(self, connection, client_address, fp=sys.stderr): """ Called when handle_client_connection raises an exception. """ # If a thread has persisted after interpreter exit, the module might be # none. if traceback: exc = traceback.format_exc() print('-' * 40, file=fp) print( "Error in processing of request from %s:%s" % ( client_address.host, client_address.port ), file=fp) print(exc, file=fp) print('-' * 40, file=fp) def handle_client_connection(self, conn, client_address): # pragma: no cover """ Called after client connection. """ raise NotImplementedError def handle_shutdown(self): """ Called after server shutdown. """