Statistics
| Branch: | Tag: | Revision:

arvados / sdk / python / arvados / keep.py @ 76d9365a

History | View | Annotate | Download (47.4 KB)

1
from __future__ import absolute_import
2
from __future__ import division
3
from future import standard_library
4
standard_library.install_aliases()
5
from builtins import next
6
from builtins import str
7
from builtins import range
8
from builtins import object
9
import io
10
import datetime
11
import hashlib
12
import logging
13
import math
14
import os
15
import pycurl
16
import queue
17
import re
18
import socket
19
import ssl
20
import sys
21
import threading
22
from . import timer
23
import urllib.parse
24

    
25
if sys.version_info >= (3, 0):
26
    from io import BytesIO
27
else:
28
    from cStringIO import StringIO as BytesIO
29

    
30
import arvados
31
import arvados.config as config
32
import arvados.errors
33
import arvados.retry as retry
34
import arvados.util
35

    
36
_logger = logging.getLogger('arvados.keep')
37
global_client_object = None
38

    
39

    
40
# Monkey patch TCP constants when not available (apple). Values sourced from:
41
# http://www.opensource.apple.com/source/xnu/xnu-2422.115.4/bsd/netinet/tcp.h
42
if sys.platform == 'darwin':
43
    if not hasattr(socket, 'TCP_KEEPALIVE'):
44
        socket.TCP_KEEPALIVE = 0x010
45
    if not hasattr(socket, 'TCP_KEEPINTVL'):
46
        socket.TCP_KEEPINTVL = 0x101
47
    if not hasattr(socket, 'TCP_KEEPCNT'):
48
        socket.TCP_KEEPCNT = 0x102
49

    
50

    
51
class KeepLocator(object):
    """Parse and manipulate one Keep block locator string.

    A locator looks like "md5sum[+size][+Asig@expiry][+Hint...]".  The md5
    digest is required; the size and hints are optional.  A permission hint
    (starting with 'A') is parsed into perm_sig/perm_expiry; every other
    hint is kept verbatim in self.hints.
    """
    EPOCH_DATETIME = datetime.datetime.utcfromtimestamp(0)
    HINT_RE = re.compile(r'^[A-Z][A-Za-z0-9@_-]+$')

    def __init__(self, locator_str):
        self.hints = []
        self._perm_sig = None
        self._perm_expiry = None
        pieces = iter(locator_str.split('+'))
        self.md5sum = next(pieces)
        try:
            self.size = int(next(pieces))
        except StopIteration:
            self.size = None
        for hint in pieces:
            if self.HINT_RE.match(hint) is None:
                raise ValueError("invalid hint format: {}".format(hint))
            elif hint.startswith('A'):
                self.parse_permission_hint(hint)
            else:
                self.hints.append(hint)

    def __str__(self):
        return '+'.join(
            str(s) for s in [self.md5sum, self.size,
                             self.permission_hint()] + self.hints
            if s is not None)

    def stripped(self):
        """Return the locator with all hints removed (md5sum[+size])."""
        if self.size is not None:
            return "%s+%i" % (self.md5sum, self.size)
        else:
            return self.md5sum

    def _make_hex_prop(name, length):
        # Build and return a new property with the given name that
        # must be a hex string of the given length.
        data_name = '_{}'.format(name)
        def getter(self):
            return getattr(self, data_name)
        def setter(self, hex_str):
            if not arvados.util.is_hex(hex_str, length):
                raise ValueError("{} is not a {}-digit hex string: {}".
                                 format(name, length, hex_str))
            setattr(self, data_name, hex_str)
        return property(getter, setter)

    md5sum = _make_hex_prop('md5sum', 32)
    perm_sig = _make_hex_prop('perm_sig', 40)

    @property
    def perm_expiry(self):
        return self._perm_expiry

    @perm_expiry.setter
    def perm_expiry(self, value):
        if not arvados.util.is_hex(value, 1, 8):
            raise ValueError(
                "permission timestamp must be a hex Unix timestamp: {}".
                format(value))
        self._perm_expiry = datetime.datetime.utcfromtimestamp(int(value, 16))

    def permission_hint(self):
        """Return the 'A<sig>@<expiry>' hint, or None when unsigned."""
        data = [self.perm_sig, self.perm_expiry]
        if None in data:
            return None
        data[1] = int((data[1] - self.EPOCH_DATETIME).total_seconds())
        return "A{}@{:08x}".format(*data)

    def parse_permission_hint(self, s):
        """Parse an 'A<sig>@<expiry>' hint into perm_sig and perm_expiry.

        Raises ValueError on a malformed hint.
        """
        try:
            self.perm_sig, self.perm_expiry = s[1:].split('@', 1)
        except (IndexError, ValueError):
            # A hint with no '@' raises ValueError from the tuple unpacking
            # (str.split never raises IndexError here), so catch both to
            # report the malformed hint with a consistent message.
            raise ValueError("bad permission hint {}".format(s))

    def permission_expired(self, as_of_dt=None):
        """Return True if the permission signature has expired.

        The expiry is stored as a naive UTC datetime (utcfromtimestamp),
        so by default it is compared against the current UTC time.
        """
        if self.perm_expiry is None:
            return False
        elif as_of_dt is None:
            # perm_expiry is naive UTC, so compare against utcnow() --
            # the local-time now() would be wrong in non-UTC timezones.
            as_of_dt = datetime.datetime.utcnow()
        return self.perm_expiry <= as_of_dt
132

    
133

    
134
class Keep(object):
    """Deprecated facade over a process-wide KeepClient singleton.

    THIS CLASS IS DEPRECATED.  Instantiate your own KeepClient with your
    own API client instead.  The shared KeepClient builds an API client
    from the current Arvados configuration, which may not match the one
    you built.
    """
    _last_key = None

    @classmethod
    def global_client_object(cls):
        global global_client_object
        # KeepClient used to react to configuration changes at runtime.
        # Emulate that here: snapshot the relevant settings, and build a
        # fresh client whenever the snapshot differs from the last one.
        current_key = (
            config.get('ARVADOS_API_HOST'),
            config.get('ARVADOS_API_TOKEN'),
            config.flag_is_true('ARVADOS_API_HOST_INSECURE'),
            config.get('ARVADOS_KEEP_PROXY'),
            config.get('ARVADOS_EXTERNAL_CLIENT') == 'true',
            os.environ.get('KEEP_LOCAL_STORE'),
        )
        if global_client_object is None or cls._last_key != current_key:
            global_client_object = KeepClient()
            cls._last_key = current_key
        return global_client_object

    @staticmethod
    def get(locator, **kwargs):
        """Fetch a block through the shared client; see KeepClient.get()."""
        return Keep.global_client_object().get(locator, **kwargs)

    @staticmethod
    def put(data, **kwargs):
        """Store data through the shared client; see KeepClient.put()."""
        return Keep.global_client_object().put(data, **kwargs)
168

    
169
class KeepBlockCache(object):
    """In-memory cache of Keep blocks, bounded by total content bytes.

    Slots are kept in most-recently-used-first order; cap_cache() evicts
    completed slots from the least-recently-used end.
    """

    # Default RAM cache is 256MiB
    def __init__(self, cache_max=(256 * 1024 * 1024)):
        self.cache_max = cache_max
        self._cache = []
        self._cache_lock = threading.Lock()

    class CacheSlot(object):
        """Holder for one block: readers block in get() until set() runs."""
        __slots__ = ("locator", "ready", "content")

        def __init__(self, locator):
            self.locator = locator
            self.ready = threading.Event()
            self.content = None

        def get(self):
            # Block until a writer has published content via set().
            self.ready.wait()
            return self.content

        def set(self, value):
            self.content = value
            self.ready.set()

        def size(self):
            """Bytes held by this slot (0 while unfilled or on error)."""
            return 0 if self.content is None else len(self.content)

    def cap_cache(self):
        '''Cap the cache size to self.cache_max'''
        with self._cache_lock:
            # Drop slots that completed with no content: those represent
            # failed block reads.
            self._cache = [
                slot for slot in self._cache
                if not (slot.ready.is_set() and slot.content is None)
            ]
            total = sum(slot.size() for slot in self._cache)
            while self._cache and total > self.cache_max:
                # Evict the least-recently-used completed slot (the list
                # is ordered most-recently-used first).
                for idx in range(len(self._cache) - 1, -1, -1):
                    if self._cache[idx].ready.is_set():
                        del self._cache[idx]
                        break
                total = sum(slot.size() for slot in self._cache)

    def _get(self, locator):
        # Linear scan; on a hit, promote the slot to the front (MRU).
        # Caller must hold self._cache_lock.
        for idx, slot in enumerate(self._cache):
            if slot.locator == locator:
                if idx != 0:
                    del self._cache[idx]
                    self._cache.insert(0, slot)
                return slot
        return None

    def get(self, locator):
        """Return the cached slot for locator, or None."""
        with self._cache_lock:
            return self._get(locator)

    def reserve_cache(self, locator):
        '''Reserve a cache slot for the specified locator,
        or return the existing slot.'''
        with self._cache_lock:
            slot = self._get(locator)
            if slot:
                return slot, False
            # Not cached yet: create a fresh slot for the caller to fill.
            # The True flag tells the caller it owns the fetch.
            slot = KeepBlockCache.CacheSlot(locator)
            self._cache.insert(0, slot)
            return slot, True
240

    
241
class Counter(object):
    """Thread-safe integer accumulator used for transfer statistics."""

    def __init__(self, v=0):
        self._mutex = threading.Lock()
        self._value = v

    def add(self, v):
        """Atomically increase the counter by v (which may be negative)."""
        with self._mutex:
            self._value += v

    def get(self):
        """Return a consistent snapshot of the current value."""
        with self._mutex:
            return self._value
253

    
254

    
255
class KeepClient(object):
256

    
257
    # Default Keep server connection timeout:  2 seconds
258
    # Default Keep server read timeout:       256 seconds
259
    # Default Keep server bandwidth minimum:  32768 bytes per second
260
    # Default Keep proxy connection timeout:  20 seconds
261
    # Default Keep proxy read timeout:        256 seconds
262
    # Default Keep proxy bandwidth minimum:   32768 bytes per second
263
    DEFAULT_TIMEOUT = (2, 256, 32768)
264
    DEFAULT_PROXY_TIMEOUT = (20, 256, 32768)
265

    
266

    
267
    class KeepService(object):
        """Make requests to a single Keep service, and track results.

        A KeepService is intended to last long enough to perform one
        transaction (GET or PUT) against one Keep service. This can
        involve calling either get() or put() multiple times in order
        to retry after transient failures. However, calling both get()
        and put() on a single instance -- or using the same instance
        to access two different Keep services -- will not produce
        sensible behavior.
        """

        # Exception types treated as HTTP-level failures: instead of
        # propagating, they are recorded in self._result['error'].
        HTTP_ERRORS = (
            socket.error,
            ssl.SSLError,
            arvados.errors.HttpError,
        )

        def __init__(self, root, user_agent_pool=queue.LifoQueue(),
                     upload_counter=None,
                     download_counter=None, **headers):
            # NOTE(review): the default user_agent_pool is a mutable default
            # argument, shared by every instance constructed without an
            # explicit pool -- confirm this sharing is intended.
            self.root = root
            self._user_agent_pool = user_agent_pool
            # Outcome of the most recent request; 'error' is False on
            # success, None before any request, or an exception object.
            self._result = {'error': None}
            self._usable = True
            self._session = None
            # GET requests advertise that raw block data is accepted; PUT
            # requests send only the caller-supplied headers (e.g. auth).
            self.get_headers = {'Accept': 'application/octet-stream'}
            self.get_headers.update(headers)
            self.put_headers = headers
            self.upload_counter = upload_counter
            self.download_counter = download_counter

        def usable(self):
            """Is it worth attempting a request?"""
            return self._usable

        def finished(self):
            """Did the request succeed or encounter permanent failure?"""
            return self._result['error'] == False or not self._usable

        def last_result(self):
            """Return the dict describing the most recent request outcome."""
            return self._result

        def _get_user_agent(self):
            # Reuse a pooled pycurl handle when available; otherwise make
            # a fresh one.
            try:
                return self._user_agent_pool.get(block=False)
            except queue.Empty:
                return pycurl.Curl()

        def _put_user_agent(self, ua):
            # Return a handle to the pool for reuse; discard it if the
            # reset (or the non-blocking put) fails for any reason.
            try:
                ua.reset()
                self._user_agent_pool.put(ua, block=False)
            except:
                ua.close()

        @staticmethod
        def _socket_open(family, socktype, protocol, address=None):
            """Because pycurl doesn't have CURLOPT_TCP_KEEPALIVE"""
            s = socket.socket(family, socktype, protocol)
            s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
            # Will throw invalid protocol error on mac. This test prevents that.
            if hasattr(socket, 'TCP_KEEPIDLE'):
                s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 75)
            s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 75)
            return s

        def get(self, locator, method="GET", timeout=None):
            """Fetch (GET) or probe (HEAD) one block from this service.

            Returns the block content (bytes) on a successful GET, True on
            a successful HEAD, or None on failure; details are available
            afterwards via last_result().
            """
            # locator is a KeepLocator object.
            url = self.root + str(locator)
            _logger.debug("Request: %s %s", method, url)
            curl = self._get_user_agent()
            ok = None
            try:
                with timer.Timer() as t:
                    self._headers = {}
                    response_body = BytesIO()
                    curl.setopt(pycurl.NOSIGNAL, 1)
                    curl.setopt(pycurl.OPENSOCKETFUNCTION, self._socket_open)
                    curl.setopt(pycurl.URL, url.encode('utf-8'))
                    curl.setopt(pycurl.HTTPHEADER, [
                        '{}: {}'.format(k,v) for k,v in self.get_headers.items()])
                    curl.setopt(pycurl.WRITEFUNCTION, response_body.write)
                    curl.setopt(pycurl.HEADERFUNCTION, self._headerfunction)
                    if method == "HEAD":
                        curl.setopt(pycurl.NOBODY, True)
                    self._setcurltimeouts(curl, timeout)

                    try:
                        curl.perform()
                    except Exception as e:
                        # Normalize any pycurl failure into the HttpError
                        # type handled by the HTTP_ERRORS clause below.
                        raise arvados.errors.HttpError(0, str(e))
                    self._result = {
                        'status_code': curl.getinfo(pycurl.RESPONSE_CODE),
                        'body': response_body.getvalue(),
                        'headers': self._headers,
                        'error': False,
                    }

                ok = retry.check_http_response_success(self._result['status_code'])
                if not ok:
                    self._result['error'] = arvados.errors.HttpError(
                        self._result['status_code'],
                        self._headers.get('x-status-line', 'Error'))
            except self.HTTP_ERRORS as e:
                self._result = {
                    'error': e,
                }
            # ok is None when check_http_response_success() could not
            # classify the failure, so the service stays usable for retry.
            self._usable = ok != False
            if self._result.get('status_code', None):
                # The client worked well enough to get an HTTP status
                # code, so presumably any problems are just on the
                # server side and it's OK to reuse the client.
                self._put_user_agent(curl)
            else:
                # Don't return this client to the pool, in case it's
                # broken.
                curl.close()
            if not ok:
                _logger.debug("Request fail: GET %s => %s: %s",
                              url, type(self._result['error']), str(self._result['error']))
                return None
            if method == "HEAD":
                _logger.info("HEAD %s: %s bytes",
                         self._result['status_code'],
                         self._result.get('content-length'))
                return True

            _logger.info("GET %s: %s bytes in %s msec (%.3f MiB/sec)",
                         self._result['status_code'],
                         len(self._result['body']),
                         t.msecs,
                         1.0*len(self._result['body'])/2**20/t.secs if t.secs > 0 else 0)

            if self.download_counter:
                self.download_counter.add(len(self._result['body']))
            # Verify the body against the md5 embedded in the locator; a
            # mismatch is treated as a failed request.
            resp_md5 = hashlib.md5(self._result['body']).hexdigest()
            if resp_md5 != locator.md5sum:
                _logger.warning("Checksum fail: md5(%s) = %s",
                                url, resp_md5)
                self._result['error'] = arvados.errors.HttpError(
                    0, 'Checksum fail')
                return None
            return self._result['body']

        def put(self, hash_s, body, timeout=None):
            """PUT one block (body, named by hash_s) to this service.

            Returns True on success, False on failure; details are
            available afterwards via last_result().
            """
            url = self.root + hash_s
            _logger.debug("Request: PUT %s", url)
            curl = self._get_user_agent()
            ok = None
            try:
                with timer.Timer() as t:
                    self._headers = {}
                    body_reader = BytesIO(body)
                    response_body = BytesIO()
                    curl.setopt(pycurl.NOSIGNAL, 1)
                    curl.setopt(pycurl.OPENSOCKETFUNCTION, self._socket_open)
                    curl.setopt(pycurl.URL, url.encode('utf-8'))
                    # Using UPLOAD tells cURL to wait for a "go ahead" from the
                    # Keep server (in the form of a HTTP/1.1 "100 Continue"
                    # response) instead of sending the request body immediately.
                    # This allows the server to reject the request if the request
                    # is invalid or the server is read-only, without waiting for
                    # the client to send the entire block.
                    curl.setopt(pycurl.UPLOAD, True)
                    curl.setopt(pycurl.INFILESIZE, len(body))
                    curl.setopt(pycurl.READFUNCTION, body_reader.read)
                    curl.setopt(pycurl.HTTPHEADER, [
                        '{}: {}'.format(k,v) for k,v in self.put_headers.items()])
                    curl.setopt(pycurl.WRITEFUNCTION, response_body.write)
                    curl.setopt(pycurl.HEADERFUNCTION, self._headerfunction)
                    self._setcurltimeouts(curl, timeout)
                    try:
                        curl.perform()
                    except Exception as e:
                        # Normalize any pycurl failure into HttpError (see get()).
                        raise arvados.errors.HttpError(0, str(e))
                    self._result = {
                        'status_code': curl.getinfo(pycurl.RESPONSE_CODE),
                        'body': response_body.getvalue().decode('utf-8'),
                        'headers': self._headers,
                        'error': False,
                    }
                ok = retry.check_http_response_success(self._result['status_code'])
                if not ok:
                    self._result['error'] = arvados.errors.HttpError(
                        self._result['status_code'],
                        self._headers.get('x-status-line', 'Error'))
            except self.HTTP_ERRORS as e:
                self._result = {
                    'error': e,
                }
            self._usable = ok != False # still usable if ok is True or None
            if self._result.get('status_code', None):
                # Client is functional. See comment in get().
                self._put_user_agent(curl)
            else:
                curl.close()
            if not ok:
                _logger.debug("Request fail: PUT %s => %s: %s",
                              url, type(self._result['error']), str(self._result['error']))
                return False
            _logger.info("PUT %s: %s bytes in %s msec (%.3f MiB/sec)",
                         self._result['status_code'],
                         len(body),
                         t.msecs,
                         1.0*len(body)/2**20/t.secs if t.secs > 0 else 0)
            if self.upload_counter:
                self.upload_counter.add(len(body))
            return True

        def _setcurltimeouts(self, curl, timeouts):
            """Apply (connect, transfer, bandwidth) timeouts to a curl handle.

            timeouts may be None (no-op), a single number used for both
            connect and transfer, a 2-tuple (connect, transfer), or a
            3-tuple (connect, transfer, minimum bytes/sec).
            """
            if not timeouts:
                return
            elif isinstance(timeouts, tuple):
                if len(timeouts) == 2:
                    conn_t, xfer_t = timeouts
                    bandwidth_bps = KeepClient.DEFAULT_TIMEOUT[2]
                else:
                    conn_t, xfer_t, bandwidth_bps = timeouts
            else:
                conn_t, xfer_t = (timeouts, timeouts)
                bandwidth_bps = KeepClient.DEFAULT_TIMEOUT[2]
            curl.setopt(pycurl.CONNECTTIMEOUT_MS, int(conn_t*1000))
            # Abort if the transfer rate stays below bandwidth_bps for
            # xfer_t seconds.
            curl.setopt(pycurl.LOW_SPEED_TIME, int(math.ceil(xfer_t)))
            curl.setopt(pycurl.LOW_SPEED_LIMIT, int(math.ceil(bandwidth_bps)))

        def _headerfunction(self, header_line):
            """pycurl HEADERFUNCTION: accumulate response headers.

            Collects headers into self._headers keyed by lowercased name;
            continuation lines are appended to the previous header, and
            the status line is stored under 'x-status-line'.
            """
            if isinstance(header_line, bytes):
                header_line = header_line.decode('iso-8859-1')
            if ':' in header_line:
                name, value = header_line.split(':', 1)
                name = name.strip().lower()
                value = value.strip()
            elif self._headers:
                # No colon and headers already seen: treat this as a
                # folded continuation of the previous header line.
                name = self._lastheadername
                value = self._headers[name] + ' ' + header_line.strip()
            elif header_line.startswith('HTTP/'):
                name = 'x-status-line'
                value = header_line
            else:
                _logger.error("Unexpected header line: %s", header_line)
                return
            self._lastheadername = name
            self._headers[name] = value
            # Returning None implies all bytes were written
512
    
513

    
514
    class KeepWriterQueue(queue.Queue):
        """Task queue coordinating KeepWriterThreads for one put().

        Queue items are (KeepService, service_root) pairs.  The queue
        tracks how many replicas have been confirmed (successful_copies)
        versus wanted (wanted_copies), and uses pending_tries plus a
        Condition to throttle how many attempts run concurrently.
        """

        def __init__(self, copies):
            queue.Queue.__init__(self) # Old-style superclass
            self.wanted_copies = copies
            self.successful_copies = 0
            # Response body from the most recent successful write.
            self.response = None
            self.successful_copies_lock = threading.Lock()
            # Number of attempts workers may start right now; decremented
            # when a task is handed out, incremented again on failure.
            self.pending_tries = copies
            self.pending_tries_notification = threading.Condition()

        def write_success(self, response, replicas_nr):
            """Record replicas_nr confirmed copies and wake all workers."""
            with self.successful_copies_lock:
                self.successful_copies += replicas_nr
                self.response = response
            with self.pending_tries_notification:
                self.pending_tries_notification.notify_all()

        def write_fail(self, ks):
            """Record a failed attempt, freeing one retry slot."""
            with self.pending_tries_notification:
                self.pending_tries += 1
                self.pending_tries_notification.notify()

        def pending_copies(self):
            """Return how many more confirmed copies are still needed."""
            with self.successful_copies_lock:
                return self.wanted_copies - self.successful_copies

        def get_next_task(self):
            """Return the next (service, service_root) a worker should try.

            Blocks until a task is available; raises queue.Empty when the
            work is done (enough copies stored, or no services left).
            """
            with self.pending_tries_notification:
                while True:
                    if self.pending_copies() < 1:
                        # This notify_all() is unnecessary --
                        # write_success() already called notify_all()
                        # when pending<1 became true, so it's not
                        # possible for any other thread to be in
                        # wait() now -- but it's cheap insurance
                        # against deadlock so we do it anyway:
                        self.pending_tries_notification.notify_all()
                        # Drain the queue and then raise Queue.Empty
                        while True:
                            self.get_nowait()
                            self.task_done()
                    elif self.pending_tries > 0:
                        service, service_root = self.get_nowait()
                        if service.finished():
                            # This service already succeeded or failed
                            # permanently; skip it.
                            self.task_done()
                            continue
                        self.pending_tries -= 1
                        return service, service_root
                    elif self.empty():
                        # Nothing queued and no retry slots: give up.
                        self.pending_tries_notification.notify_all()
                        raise queue.Empty
                    else:
                        # Tasks exist but all attempt slots are in use;
                        # wait for a success or failure notification.
                        self.pending_tries_notification.wait()
567

    
568

    
569
    class KeepWriterThreadPool(object):
        """Run a group of KeepWriterThreads until enough copies are stored."""

        def __init__(self, data, data_hash, copies, max_service_replicas, timeout=None):
            self.total_task_nr = 0
            self.wanted_copies = copies
            # One worker suffices when a single service can confirm all the
            # requested replicas; otherwise spread the work across threads.
            if max_service_replicas and max_service_replicas < copies:
                thread_count = int(math.ceil(1.0*copies/max_service_replicas))
            else:
                thread_count = 1
            _logger.debug("Pool max threads is %d", thread_count)
            self.queue = KeepClient.KeepWriterQueue(copies)
            self.workers = [
                KeepClient.KeepWriterThread(self.queue, data, data_hash, timeout)
                for _ in range(thread_count)
            ]

        def add_task(self, ks, service_root):
            """Enqueue one (service, service_root) write attempt."""
            self.queue.put((ks, service_root))
            self.total_task_nr += 1

        def done(self):
            """Return the number of replicas confirmed written so far."""
            return self.queue.successful_copies

        def join(self):
            """Start every worker, then block until the queue is drained."""
            for worker in self.workers:
                worker.start()
            self.queue.join()

        def response(self):
            """Return the most recent successful write response, or None."""
            return self.queue.response
601
    
602
    
603
    class KeepWriterThread(threading.Thread):
        """Worker thread: pulls tasks from a KeepWriterQueue and PUTs blocks."""

        # Sentinel exception instance: raised by do_task() to report an
        # ordinary failed PUT; run() recognizes it by identity and skips
        # the full-traceback logging reserved for unexpected exceptions.
        TaskFailed = RuntimeError()

        def __init__(self, queue, data, data_hash, timeout=None):
            super(KeepClient.KeepWriterThread, self).__init__()
            self.timeout = timeout
            self.queue = queue
            self.data = data
            self.data_hash = data_hash
            # Daemon thread: don't keep the process alive for stuck writes.
            self.daemon = True

        def run(self):
            # Pull (service, service_root) tasks until the queue signals
            # completion by raising queue.Empty, reporting each outcome
            # back to the queue.
            while True:
                try:
                    service, service_root = self.queue.get_next_task()
                except queue.Empty:
                    return
                try:
                    locator, copies = self.do_task(service, service_root)
                except Exception as e:
                    if e is not self.TaskFailed:
                        _logger.exception("Exception in KeepWriterThread")
                    self.queue.write_fail(service)
                else:
                    self.queue.write_success(locator, copies)
                finally:
                    self.queue.task_done()

        def do_task(self, service, service_root):
            """PUT the block to one service; return (locator, replica count).

            Raises self.TaskFailed when the service reports failure.
            """
            success = bool(service.put(self.data_hash,
                                        self.data,
                                        timeout=self.timeout))
            result = service.last_result()

            if not success:
                if result.get('status_code', None):
                    _logger.debug("Request fail: PUT %s => %s %s",
                                  self.data_hash,
                                  result['status_code'],
                                  result['body'])
                raise self.TaskFailed

            _logger.debug("KeepWriterThread %s succeeded %s+%i %s",
                          str(threading.current_thread()),
                          self.data_hash,
                          len(self.data),
                          service_root)
            try:
                replicas_stored = int(result['headers']['x-keep-replicas-stored'])
            except (KeyError, ValueError):
                # Header absent or unparseable: assume one replica stored.
                replicas_stored = 1

            return result['body'].strip(), replicas_stored
656

    
657

    
658
    def __init__(self, api_client=None, proxy=None,
659
                 timeout=DEFAULT_TIMEOUT, proxy_timeout=DEFAULT_PROXY_TIMEOUT,
660
                 api_token=None, local_store=None, block_cache=None,
661
                 num_retries=0, session=None):
662
        """Initialize a new KeepClient.
663

664
        Arguments:
665
        :api_client:
666
          The API client to use to find Keep services.  If not
667
          provided, KeepClient will build one from available Arvados
668
          configuration.
669

670
        :proxy:
671
          If specified, this KeepClient will send requests to this Keep
672
          proxy.  Otherwise, KeepClient will fall back to the setting of the
673
          ARVADOS_KEEP_SERVICES or ARVADOS_KEEP_PROXY configuration settings.
674
          If you want to KeepClient does not use a proxy, pass in an empty
675
          string.
676

677
        :timeout:
678
          The initial timeout (in seconds) for HTTP requests to Keep
679
          non-proxy servers.  A tuple of three floats is interpreted as
680
          (connection_timeout, read_timeout, minimum_bandwidth). A connection
681
          will be aborted if the average traffic rate falls below
682
          minimum_bandwidth bytes per second over an interval of read_timeout
683
          seconds. Because timeouts are often a result of transient server
684
          load, the actual connection timeout will be increased by a factor
685
          of two on each retry.
686
          Default: (2, 256, 32768).
687

688
        :proxy_timeout:
689
          The initial timeout (in seconds) for HTTP requests to
690
          Keep proxies. A tuple of three floats is interpreted as
691
          (connection_timeout, read_timeout, minimum_bandwidth). The behavior
692
          described above for adjusting connection timeouts on retry also
693
          applies.
694
          Default: (20, 256, 32768).
695

696
        :api_token:
697
          If you're not using an API client, but only talking
698
          directly to a Keep proxy, this parameter specifies an API token
699
          to authenticate Keep requests.  It is an error to specify both
700
          api_client and api_token.  If you specify neither, KeepClient
701
          will use one available from the Arvados configuration.
702

703
        :local_store:
704
          If specified, this KeepClient will bypass Keep
705
          services, and save data to the named directory.  If unspecified,
706
          KeepClient will fall back to the setting of the $KEEP_LOCAL_STORE
707
          environment variable.  If you want to ensure KeepClient does not
708
          use local storage, pass in an empty string.  This is primarily
709
          intended to mock a server for testing.
710

711
        :num_retries:
712
          The default number of times to retry failed requests.
713
          This will be used as the default num_retries value when get() and
714
          put() are called.  Default 0.
715
        """
716
        self.lock = threading.Lock()
717
        if proxy is None:
718
            if config.get('ARVADOS_KEEP_SERVICES'):
719
                proxy = config.get('ARVADOS_KEEP_SERVICES')
720
            else:
721
                proxy = config.get('ARVADOS_KEEP_PROXY')
722
        if api_token is None:
723
            if api_client is None:
724
                api_token = config.get('ARVADOS_API_TOKEN')
725
            else:
726
                api_token = api_client.api_token
727
        elif api_client is not None:
728
            raise ValueError(
729
                "can't build KeepClient with both API client and token")
730
        if local_store is None:
731
            local_store = os.environ.get('KEEP_LOCAL_STORE')
732

    
733
        self.block_cache = block_cache if block_cache else KeepBlockCache()
734
        self.timeout = timeout
735
        self.proxy_timeout = proxy_timeout
736
        self._user_agent_pool = queue.LifoQueue()
737
        self.upload_counter = Counter()
738
        self.download_counter = Counter()
739
        self.put_counter = Counter()
740
        self.get_counter = Counter()
741
        self.hits_counter = Counter()
742
        self.misses_counter = Counter()
743

    
744
        if local_store:
745
            self.local_store = local_store
746
            self.get = self.local_store_get
747
            self.put = self.local_store_put
748
        else:
749
            self.num_retries = num_retries
750
            self.max_replicas_per_service = None
751
            if proxy:
752
                proxy_uris = proxy.split()
753
                for i in range(len(proxy_uris)):
754
                    if not proxy_uris[i].endswith('/'):
755
                        proxy_uris[i] += '/'
756
                    # URL validation
757
                    url = urllib.parse.urlparse(proxy_uris[i])
758
                    if not (url.scheme and url.netloc):
759
                        raise arvados.errors.ArgumentError("Invalid proxy URI: {}".format(proxy_uris[i]))
760
                self.api_token = api_token
761
                self._gateway_services = {}
762
                self._keep_services = [{
763
                    'uuid': "00000-bi6l4-%015d" % idx,
764
                    'service_type': 'proxy',
765
                    '_service_root': uri,
766
                    } for idx, uri in enumerate(proxy_uris)]
767
                self._writable_services = self._keep_services
768
                self.using_proxy = True
769
                self._static_services_list = True
770
            else:
771
                # It's important to avoid instantiating an API client
772
                # unless we actually need one, for testing's sake.
773
                if api_client is None:
774
                    api_client = arvados.api('v1')
775
                self.api_client = api_client
776
                self.api_token = api_client.api_token
777
                self._gateway_services = {}
778
                self._keep_services = None
779
                self._writable_services = None
780
                self.using_proxy = None
781
                self._static_services_list = False
782

    
783
    def current_timeout(self, attempt_number):
        """Return the timeout tuple appropriate for this client right now.

        Uses the proxy timeout settings when the backend service is
        currently a proxy, the regular timeout settings otherwise.
        `attempt_number` is how many times the operation has already
        been tried (0 for the first try); the connection-timeout
        portion of the result doubles with each retry.
        """
        # TODO(twp): the timeout should be a property of a
        # KeepService, not a KeepClient. See #4488.
        base = self.proxy_timeout if self.using_proxy else self.timeout
        scaled_connect = base[0] * (1 << attempt_number)
        if len(base) == 2:
            return (scaled_connect, base[1])
        return (scaled_connect, base[1], base[2])
800
    def _any_nondisk_services(self, service_list):
801
        return any(ks.get('service_type', 'disk') != 'disk'
802
                   for ks in service_list)
803

    
804
    def build_services_list(self, force_rebuild=False):
        """Populate the cached lists of Keep services from the API server.

        No-op when this client was built with a static proxy service
        list, or when the service list is already cached and
        force_rebuild is false.  Otherwise polls the API server for
        accessible Keep services (falling back to the older keep_disks
        endpoint when the server predates Keep services) and rebuilds
        self._gateway_services, self._keep_services and
        self._writable_services under self.lock.

        Raises arvados.errors.NoKeepServersError when the API server
        reports no accessible services.
        """
        if (self._static_services_list or
              (self._keep_services and not force_rebuild)):
            return
        with self.lock:
            try:
                keep_services = self.api_client.keep_services().accessible()
            except Exception:  # API server predates Keep services.
                keep_services = self.api_client.keep_disks().list()

            # Gateway services are only used when specified by UUID,
            # so there's nothing to gain by filtering them by
            # service_type.
            self._gateway_services = {ks['uuid']: ks for ks in
                                      keep_services.execute()['items']}
            if not self._gateway_services:
                raise arvados.errors.NoKeepServersError()

            # Precompute the base URI for each service.
            for r in self._gateway_services.values():
                host = r['service_host']
                if not host.startswith('[') and host.find(':') >= 0:
                    # IPv6 URIs must be formatted like http://[::1]:80/...
                    host = '[' + host + ']'
                r['_service_root'] = "{}://{}:{:d}/".format(
                    'https' if r['service_ssl_flag'] else 'http',
                    host,
                    r['service_port'])

            _logger.debug(str(self._gateway_services))
            # Local read candidates are everything except gateway
            # services; write candidates additionally exclude read-only
            # services.
            self._keep_services = [
                ks for ks in self._gateway_services.values()
                if not ks.get('service_type', '').startswith('gateway:')]
            self._writable_services = [ks for ks in self._keep_services
                                       if not ks.get('read_only')]

            # For disk type services, max_replicas_per_service is 1
            # It is unknown (unlimited) for other service types.
            if self._any_nondisk_services(self._writable_services):
                self.max_replicas_per_service = None
            else:
                self.max_replicas_per_service = 1
846

    
847
    def _service_weight(self, data_hash, service_uuid):
848
        """Compute the weight of a Keep service endpoint for a data
849
        block with a known hash.
850

851
        The weight is md5(h + u) where u is the last 15 characters of
852
        the service endpoint's UUID.
853
        """
854
        return hashlib.md5((data_hash + service_uuid[-15:]).encode()).hexdigest()
855

    
856
    def weighted_service_roots(self, locator, force_rebuild=False, need_writable=False):
        """Return Keep service endpoints in probe order.

        The result is the list of base URIs to try, in order, when
        reading or writing the data block identified by the given
        locator, honoring any +K@ hints it carries.
        """
        self.build_services_list(force_rebuild)

        sorted_roots = []
        # Honor +K@... hints first, when present and resolvable: a
        # 5-character hint names a remote cluster's keepproxy; a
        # 27-character hint names a local gateway service by UUID.
        for hint in locator.hints:
            if not hint.startswith('K@'):
                continue
            if len(hint) == 7:
                sorted_roots.append(
                    "https://keep.{}.arvadosapi.com/".format(hint[2:]))
            elif len(hint) == 29:
                svc = self._gateway_services.get(hint[2:])
                if svc:
                    sorted_roots.append(svc['_service_root'])

        # Then the available local services, heaviest first for this
        # locator; return their service_roots (base URIs) in that order.
        use_services = self._writable_services if need_writable else self._keep_services
        self.using_proxy = self._any_nondisk_services(use_services)
        ranked = sorted(
            use_services,
            reverse=True,
            key=lambda svc: self._service_weight(locator.md5sum, svc['uuid']))
        sorted_roots.extend(svc['_service_root'] for svc in ranked)
        _logger.debug("{}: {}".format(locator, sorted_roots))
        return sorted_roots
891

    
892
    def map_new_services(self, roots_map, locator, force_rebuild, need_writable, **headers):
        """Refresh roots_map with KeepService objects for current services.

        roots_map is a dictionary mapping Keep service root strings to
        KeepService objects.  Polls for Keep services, adds an entry
        for every root not already present, and returns the current
        list of local root strings in probe order.
        """
        headers.setdefault('Authorization', "OAuth2 %s" % (self.api_token,))
        local_roots = self.weighted_service_roots(locator, force_rebuild, need_writable)
        missing = (root for root in local_roots if root not in roots_map)
        for root in missing:
            roots_map[root] = self.KeepService(
                root, self._user_agent_pool,
                upload_counter=self.upload_counter,
                download_counter=self.download_counter,
                **headers)
        return local_roots
907

    
908
    @staticmethod
909
    def _check_loop_result(result):
910
        # KeepClient RetryLoops should save results as a 2-tuple: the
911
        # actual result of the request, and the number of servers available
912
        # to receive the request this round.
913
        # This method returns True if there's a real result, False if
914
        # there are no more servers available, otherwise None.
915
        if isinstance(result, Exception):
916
            return None
917
        result, tried_server_count = result
918
        if (result is not None) and (result is not False):
919
            return True
920
        elif tried_server_count < 1:
921
            _logger.info("No more Keep services to try; giving up")
922
            return False
923
        else:
924
            return None
925

    
926
    def get_from_cache(self, loc):
        """Fetch a block only if it is in the cache, otherwise return None.

        A slot whose ready event is not yet set (fetch still in flight)
        also counts as a miss.
        """
        slot = self.block_cache.get(loc)
        if slot is None or not slot.ready.is_set():
            return None
        return slot.get()
933

    
934
    @retry.retry_method
    def head(self, loc_s, num_retries=None):
        """Check whether a block is available in Keep.

        Thin wrapper around _get_or_head() with method="HEAD"; see that
        method for argument semantics and failure behavior.
        """
        return self._get_or_head(loc_s, method="HEAD", num_retries=num_retries)
937

    
938
    @retry.retry_method
    def get(self, loc_s, num_retries=None):
        """Fetch block data from Keep.

        Thin wrapper around _get_or_head() with method="GET"; see that
        method for argument semantics and failure behavior.
        """
        return self._get_or_head(loc_s, method="GET", num_retries=num_retries)
941

    
942
    def _get_or_head(self, loc_s, method="GET", num_retries=None):
        """Get data from Keep.

        This method fetches one or more blocks of data from Keep.  It
        sends a request each Keep service registered with the API
        server (or the proxy provided when this client was
        instantiated), then each service named in location hints, in
        sequence.  As soon as one service provides the data, it's
        returned.

        Arguments:
        * loc_s: A string of one or more comma-separated locators to fetch.
          This method returns the concatenation of these blocks.
        * method: "GET" to fetch block data; "HEAD" to only check
          availability (returns True on success instead of the data).
        * num_retries: The number of times to retry GET requests to
          *each* Keep server if it returns temporary failures, with
          exponential backoff.  Note that, in each loop, the method may try
          to fetch data from every available Keep service, along with any
          that are named in location hints in the locator.  The default value
          is set when the KeepClient is initialized.
        """
        if ',' in loc_s:
            # NOTE(review): fetched blocks appear to be bytes (see
            # isinstance(data, bytes) handling in put()); ''.join over
            # bytes raises TypeError under Python 3 -- confirm whether
            # b''.join is intended here.
            return ''.join(self.get(x) for x in loc_s.split(','))

        self.get_counter.add(1)

        locator = KeepLocator(loc_s)
        if method == "GET":
            # Reserve a cache slot; if another caller already reserved
            # it, wait for (or reuse) that result instead of
            # re-fetching.
            slot, first = self.block_cache.reserve_cache(locator.md5sum)
            if not first:
                self.hits_counter.add(1)
                v = slot.get()
                return v

        self.misses_counter.add(1)

        # If the locator has hints specifying a prefix (indicating a
        # remote keepproxy) or the UUID of a local gateway service,
        # read data from the indicated service(s) instead of the usual
        # list of local disk services.
        hint_roots = ['http://keep.{}.arvadosapi.com/'.format(hint[2:])
                      for hint in locator.hints if hint.startswith('K@') and len(hint) == 7]
        hint_roots.extend([self._gateway_services[hint[2:]]['_service_root']
                           for hint in locator.hints if (
                                   hint.startswith('K@') and
                                   len(hint) == 29 and
                                   self._gateway_services.get(hint[2:])
                                   )])
        # Map root URLs to their KeepService objects.
        roots_map = {
            root: self.KeepService(root, self._user_agent_pool,
                                   upload_counter=self.upload_counter,
                                   download_counter=self.download_counter)
            for root in hint_roots
        }

        # See #3147 for a discussion of the loop implementation.  Highlights:
        # * Refresh the list of Keep services after each failure, in case
        #   it's being updated.
        # * Retry until we succeed, we're out of retries, or every available
        #   service has returned permanent failure.
        # NOTE(review): roots_map is rebound to {} here, discarding the
        # hint-root KeepService objects built just above -- confirm
        # whether that is intentional.
        sorted_roots = []
        roots_map = {}
        blob = None
        loop = retry.RetryLoop(num_retries, self._check_loop_result,
                               backoff_start=2)
        for tries_left in loop:
            try:
                sorted_roots = self.map_new_services(
                    roots_map, locator,
                    force_rebuild=(tries_left < num_retries),
                    need_writable=False)
            except Exception as error:
                loop.save_result(error)
                continue

            # Query KeepService objects that haven't returned
            # permanent failure, in our specified shuffle order.
            services_to_try = [roots_map[root]
                               for root in sorted_roots
                               if roots_map[root].usable()]
            for keep_service in services_to_try:
                blob = keep_service.get(locator, method=method, timeout=self.current_timeout(num_retries-tries_left))
                if blob is not None:
                    break
            loop.save_result((blob, len(services_to_try)))

        # Always cache the result, then return it if we succeeded.
        if method == "GET":
            # slot was reserved above in the method == "GET" branch;
            # setting it (even to None) unblocks other waiting readers.
            slot.set(blob)
            self.block_cache.cap_cache()
        if loop.success():
            if method == "HEAD":
                return True
            else:
                return blob

        # Q: Including 403 is necessary for the Keep tests to continue
        # passing, but maybe they should expect KeepReadError instead?
        not_founds = sum(1 for key in sorted_roots
                         if roots_map[key].last_result().get('status_code', None) in {403, 404, 410})
        service_errors = ((key, roots_map[key].last_result()['error'])
                          for key in sorted_roots)
        if not roots_map:
            raise arvados.errors.KeepReadError(
                "failed to read {}: no Keep services available ({})".format(
                    loc_s, loop.last_result()))
        elif not_founds == len(sorted_roots):
            raise arvados.errors.NotFoundError(
                "{} not found".format(loc_s), service_errors)
        else:
            raise arvados.errors.KeepReadError(
                "failed to read {}".format(loc_s), service_errors, label="service")
1054

    
1055
    @retry.retry_method
    def put(self, data, copies=2, num_retries=None):
        """Save data in Keep.

        This method will get a list of Keep services from the API server, and
        send the data to each one simultaneously in a new thread.  Once the
        uploads are finished, if enough copies are saved, this method returns
        the most recent HTTP response body.  If requests fail to upload
        enough copies, this method raises KeepWriteError.

        Arguments:
        * data: The string of data to upload.
        * copies: The number of copies that the user requires be saved.
          Default 2.
        * num_retries: The number of times to retry PUT requests to
          *each* Keep server if it returns temporary failures, with
          exponential backoff.  The default value is set when the
          KeepClient is initialized.
        """

        if not isinstance(data, bytes):
            # Hashing and upload operate on bytes; encode text input.
            data = data.encode()

        self.put_counter.add(1)

        data_hash = hashlib.md5(data).hexdigest()
        loc_s = data_hash + '+' + str(len(data))
        if copies < 1:
            # No copies requested: the locator alone satisfies the call.
            return loc_s
        locator = KeepLocator(loc_s)

        headers = {}
        # Tell the proxy how many copies we want it to store
        headers['X-Keep-Desired-Replicas'] = str(copies)
        roots_map = {}
        loop = retry.RetryLoop(num_retries, self._check_loop_result,
                               backoff_start=2)
        done = 0
        for tries_left in loop:
            try:
                sorted_roots = self.map_new_services(
                    roots_map, locator,
                    force_rebuild=(tries_left < num_retries), need_writable=True, **headers)
            except Exception as error:
                loop.save_result(error)
                continue

            # Upload in parallel, asking only for the copies still
            # missing, and skipping services that already finished
            # (succeeded or failed permanently) in an earlier round.
            writer_pool = KeepClient.KeepWriterThreadPool(data=data,
                                                        data_hash=data_hash,
                                                        copies=copies - done,
                                                        max_service_replicas=self.max_replicas_per_service,
                                                        timeout=self.current_timeout(num_retries - tries_left))
            for service_root, ks in [(root, roots_map[root])
                                     for root in sorted_roots]:
                if ks.finished():
                    continue
                writer_pool.add_task(ks, service_root)
            writer_pool.join()
            done += writer_pool.done()
            loop.save_result((done >= copies, writer_pool.total_task_nr))

        if loop.success():
            return writer_pool.response()
        if not roots_map:
            raise arvados.errors.KeepWriteError(
                "failed to write {}: no Keep services available ({})".format(
                    data_hash, loop.last_result()))
        else:
            service_errors = ((key, roots_map[key].last_result()['error'])
                              for key in sorted_roots
                              if roots_map[key].last_result()['error'])
            raise arvados.errors.KeepWriteError(
                "failed to write {} (wanted {} copies but wrote {})".format(
                    data_hash, copies, writer_pool.done()), service_errors, label="service")
1129

    
1130
    def local_store_put(self, data, copies=1, num_retries=None):
        """A stub for put().

        This method is used in place of the real put() method when
        using local storage (see constructor's local_store argument).

        copies and num_retries arguments are ignored: they are here
        only for the sake of offering the same call signature as
        put().

        Data stored this way can be retrieved via local_store_get().
        """
        # put() accepts both str and bytes; normalize to bytes so that
        # hashing and the file write behave the same here.  (The old
        # code hashed/wrote `data` as-is, which fails under Python 3:
        # md5() rejects str, and text-mode write rejects bytes.)
        if not isinstance(data, bytes):
            data = data.encode()
        md5 = hashlib.md5(data).hexdigest()
        locator = '%s+%d' % (md5, len(data))
        # Write to a temporary name, then rename into place, so readers
        # never observe a partially written block.  Binary mode because
        # block data is bytes.
        with open(os.path.join(self.local_store, md5 + '.tmp'), 'wb') as f:
            f.write(data)
        os.rename(os.path.join(self.local_store, md5 + '.tmp'),
                  os.path.join(self.local_store, md5))
        return locator
1149

    
1150
    def local_store_get(self, loc_s, num_retries=None):
        """Companion to local_store_put().

        Returns the content of a block previously stored with
        local_store_put(), as bytes — matching what the real get()
        path handles (put() normalizes data to bytes).

        Raises arvados.errors.NotFoundError when loc_s is not a valid
        locator.
        """
        try:
            locator = KeepLocator(loc_s)
        except ValueError:
            raise arvados.errors.NotFoundError(
                "Invalid data locator: '%s'" % loc_s)
        if locator.md5sum == config.EMPTY_BLOCK_LOCATOR.split('+')[0]:
            # The well-known empty block is never written to disk.
            return b''
        # Binary mode: block data is bytes; text mode would fail (or
        # silently corrupt) non-UTF-8 content under Python 3.
        with open(os.path.join(self.local_store, locator.md5sum), 'rb') as f:
            return f.read()
1161

    
1162
    def is_cached(self, locator):
        """Reserve/inspect the block cache entry for the given locator.

        Returns the (slot, first) pair from
        block_cache.reserve_cache(), as used by _get_or_head().
        (Fixes a NameError: the previous code referenced the undefined
        name `expect_hash` instead of the `locator` parameter.)
        """
        return self.block_cache.reserve_cache(locator)