Skiplist スキップリスト: 並行処理編 Part. 2

ソースはGitHubに移行しました。

SkiplistのLock-Free版はいくつか提案されている:

またJavaでは標準ライブラリ(java.util.concurrentパッケージのConcurrentSkipListMap)として提供されている。


Lock-Free Skiplist(単方向 Lock-Free スキップリスト)

ここでは"A Lock-Free concurrent skiplist with wait-free search" を実装した。
今のところ32-bit版Linuxディストリビューション上でのみ動く。Intel64(X86_64)環境には対応していない。

ソースはGitHubに移行しました。

データ構造

目新しいのはpthread_keyを使った作業領域確保の部分か。 普通の実装は作業領域は一つで済むが、 Lock-Free版は複数スレッドが同時に動作するのでスレッド毎に作業領域が必要になる。

LockFreeSkiplist.h:

/*
 * Mark state of one tower link. MARKED on a node's link at some level means
 * the node is logically deleted at that level.
 * NOTE(review): MARKED is 0, so a zero-filled (calloc'd) tower_ref starts out
 * MARKED; every tower entry has to be set to UNMARKED explicitly.
 */
typedef enum {MARKED = 0, UNMARKED = 1} node_stat;

/*
 * One forward link of a node together with its mark. The struct is packed
 * so that mark (low dword) and next_node_ptr (high dword) form a single
 * 8-byte unit that cas64() swaps with cmpxchg8b; this only holds where
 * pointers are 32 bits wide.
 */
typedef struct _tower_ref {
  node_stat mark;
  struct _skiplist_node_t *next_node_ptr;
}__attribute__((packed)) tower_ref;


/* A skiplist node. */
typedef struct _skiplist_node_t {
  lkey_t key;
  val_t val;
  int topLevel;      /* highest index allocated in tower */
  tower_ref *tower;  /* separately allocated array of topLevel + 1 links */

} skiplist_node_t;

/* A concurrent skiplist. */
typedef struct _skiplist_t {
  int	maxLevel;   /* search() walks levels maxLevel - 1 down to 0 */

  skiplist_node_t *head;  /* sentinel created with the minimum key */
  skiplist_node_t *tail;  /* sentinel created with the maximum key */

  pthread_key_t workspace_key;  /* key of the per-thread workspace_t */
} skiplist_t;


/*
 * Per-thread scratch space filled by search(): for each level, the last node
 * before the searched key (preds) and the first node at or after it (succs).
 * Both arrays hold maxLevel entries.
 */
typedef struct _workspace_t {
  skiplist_node_t **preds;
  skiplist_node_t **succs;
} workspace_t;

領域確保の関数群はconcurrent_skiplist.hに記述してある。

concurrent_skiplist.h:
/*
 * static workspace_t *init_workspace(const int maxLevel)
 *
 * Allocate one thread's scratch area for a skiplist: two zero-filled arrays
 * (preds and succs) of maxLevel node pointers each.
 *
 * Returns the new workspace on success, or NULL if maxLevel < 1 or an
 * allocation fails. Release it with free_workspace().
 */
static workspace_t *init_workspace(const int maxLevel)
{
    workspace_t *ws;

    /* A skiplist needs at least one level; a negative count would also turn
     * into a huge size_t in the allocations below. */
    if (maxLevel < 1) {
	elog("invalid maxLevel");
	return NULL;
    }

    if ((ws = calloc(1, sizeof *ws)) == NULL) {
	elog("calloc error");
	return NULL;
    }

    /* calloc(n, size) checks n * size for overflow; calloc(1, size * n)
     * would silently wrap and return a too-small buffer. */
    if ((ws->preds = calloc((size_t) maxLevel, sizeof *ws->preds)) == NULL) {
	elog("calloc error");
	goto end;
    }
    if ((ws->succs = calloc((size_t) maxLevel, sizeof *ws->succs)) == NULL) {
	elog("calloc error");
	goto end;
    }

    return ws;

  end:
    /* ws is zero-filled, so preds is still NULL if its own allocation was
     * the one that failed; free(NULL) is a no-op. */
    free(ws->preds);
    free(ws);
    return NULL;
}

/*
 * void free_workspace(workspace_t *ws)
 *
 * Release a workspace obtained from init_workspace(): first the two scratch
 * arrays, then the workspace record itself.
 */
static void free_workspace(workspace_t * ws)
{
    skiplist_node_t **pred_arr = ws->preds;
    skiplist_node_t **succ_arr = ws->succs;

    free(pred_arr);
    free(succ_arr);
    free(ws);
}

/*
 * workspace_t *get_workspace(skiplist_t *sl)
 *
 * Return the calling thread's workspace for sl. The first call from a given
 * thread allocates the workspace and registers it under sl->workspace_key;
 * later calls return the registered one. Allocation or registration failure
 * aborts the process, so the result is never NULL.
 */
static workspace_t *get_workspace(skiplist_t * sl)
{
  workspace_t *ws = pthread_getspecific(sl->workspace_key);

  /* Fast path: this thread already owns a workspace. */
  if (ws != NULL)
    return ws;

  ws = init_workspace(sl->maxLevel);
  if (ws == NULL) {
    elog("init_workspace() error");
    abort();
  }

  if (pthread_setspecific(sl->workspace_key, (void *) ws) != 0) {
    elog("pthread_setspecific() error");
    abort();
  }

  assert(ws != NULL);
  return ws;
}

基本関数

Lock-Freeのためのプリミティブ

ポインタとmarkをatomicに操作するため、32bit+32bitの構造体tower_refを定義し、値の更新はcmpxchg8bで行う。
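ポインタサイズを4byte=32bitと仮定してコーディングしているので、Intel64(X86_64)環境ではmark(32bit)+ポインタ(64bit)が8byteに収まらず、cmpxchg8bでは全体をatomicに更新できない。X86_64に対応するには、構造体を16byteに揃えてcmpxchg16bを使うか、markをポインタの最下位ビットに埋め込んで8byteのcmpxchg(cmpxchgq)を使うなど、データ構造そのものの変更が必要になる。
近々対応したいが、興味を持って先に実装してしまった人がいたら、是非連絡ください。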

/*
 * Repeated from LockFreeSkiplist.h for reference; the struct tag may only be
 * defined once per translation unit. Field order matters: mark is the low
 * dword and next_node_ptr the high dword, matching EAX/EBX and EDX/ECX in
 * cas64() (32-bit x86 only).
 */
typedef struct _tower_ref {
  node_stat mark;
  struct _skiplist_node_t *next_node_ptr;
}__attribute__((packed)) tower_ref;


inline bool_t cas64(volatile tower_ref * addr, const tower_ref oldp,
		  const tower_ref newp)
{
    char result;
    __asm__ __volatile__("lock; cmpxchg8b %0; setz %1":"=m"(*(volatile tower_ref *)addr),
			 "=q"(result)
			 :"m"(*(volatile tower_ref *)addr), "a"(oldp.mark),
			 "d"(oldp.next_node_ptr), "b"(newp.mark),
			 "c"(newp.next_node_ptr)
			 :"memory");

    return (((int) result != 0) ? true : false);
}

static tower_ref make_ref(const skiplist_node_t * next_node_ptr,
			  const node_stat mark)
{
    tower_ref ref;
    ref.mark = mark;
    ref.next_node_ptr = (skiplist_node_t *) next_node_ptr;
    return ref;
}
listの生成
skiplist_t *init_list(const int maxLevel, const lkey_t min, const lkey_t max)
{
    int i;
    skiplist_t *sl;
    skiplist_node_t *head, *tail;


    if ((sl = (skiplist_t *) calloc(1, sizeof(skiplist_t))) == NULL) {
      elog("calloc error");
      return NULL;
    }

    sl->maxLevel = maxLevel;

    if ((head = create_node(maxLevel, min, min)) == NULL) {
      elog("create_node() error");
      goto end;
    }

    if ((tail = create_node(maxLevel, max, max)) == NULL) {
      elog("create_node() error");
      goto end;
    }

    for (i = 0; i < maxLevel; i++) {
	head->tower[i].next_node_ptr = tail;
	tail->tower[i].next_node_ptr = NULL;
    }

    sl->head = head;
    sl->tail = tail;

    if (pthread_key_create(&sl->workspace_key, (void *) free_workspace) != 0) {
	elog("pthread_key_create() error");
	abort();
    }


    return sl;
 end:
    free(sl->head);
    free(sl);
    return NULL;
}


/*
 * void free_skiplist(skiplist_t *sl)
 *
 * Destroy a skiplist. This frees:
 *   - every node still linked at level 0;
 *   - the head and tail sentinels;
 *   - the calling thread's workspace;
 *   - the per-thread key;
 *   - the list record itself.
 * NULL is accepted and ignored.
 *
 * Must only be called once no other thread uses sl. Workspaces of other
 * threads that are still alive are not reclaimed: once the key is deleted,
 * pthread no longer runs their destructors.
 */
void free_skiplist(skiplist_t * sl)
{
    skiplist_node_t *node, *next;
    workspace_t *ws;

    if (sl == NULL)
	return;

    /* Level 0 links every node in the list, so walking it from head to tail
     * reaches all interior nodes. */
    node = sl->head->tower[0].next_node_ptr;
    while (node != NULL && node != sl->tail) {
	next = node->tower[0].next_node_ptr;
	free_node(node);
	node = next;
    }
    free_node(sl->head);
    free_node(sl->tail);

    /* The caller's workspace would otherwise only be released at thread
     * exit, which never happens once the key has been deleted. */
    ws = pthread_getspecific(sl->workspace_key);
    if (ws != NULL) {
	pthread_setspecific(sl->workspace_key, NULL);
	free_workspace(ws);
    }
    pthread_key_delete(sl->workspace_key);

    free(sl);
}
ノードの検索
/*
 * static bool_t search(skiplist_t *sl, const lkey_t key,
 *                      skiplist_node_t **preds, skiplist_node_t **succs)
 *
 * Locate key. For every level from sl->maxLevel - 1 down to 0 this stores:
 *   - in preds[level], the last node whose key is < key;
 *   - in succs[level], the first node whose key is >= key.
 * Marked (logically deleted) nodes met on the way are physically unlinked
 * with cas64(). If such a CAS fails, the whole search restarts from head.
 *
 * Returns true if succs[0] holds key.
 */
static bool_t search(skiplist_t * sl, const lkey_t key, skiplist_node_t ** preds,
		 skiplist_node_t ** succs)
{
    int level;
    skiplist_node_t *pred, *curr, *succ;
    node_stat marked;

  retry:
    pred = sl->head;
    curr = NULL;

    for (level = sl->maxLevel - 1; level >= 0; level--) {
	curr = pred->tower[level].next_node_ptr;

	while (1) {
	    /*
	     * Read the mark BEFORE the pointer. A link's pointer only changes
	     * while the link is UNMARKED, and a MARKED link never changes
	     * again, so this order yields a (pointer, mark) pair that really
	     * existed at some instant. Reading the pointer first could pair a
	     * stale successor with a later MARKED bit. The snip below would
	     * then unlink nodes inserted in between. The empty asm is a
	     * compiler barrier; x86 does not reorder loads with loads.
	     */
	    marked = curr->tower[level].mark;
	    __asm__ __volatile__("":::"memory");
	    succ = curr->tower[level].next_node_ptr;

	    while (marked == MARKED) {
		/* Unlink curr at this level; if pred changed, start over. */
		if (cas64(&pred->tower[level], make_ref(curr, UNMARKED),
			  make_ref(succ, UNMARKED)) != true)
		    goto retry;

		curr = pred->tower[level].next_node_ptr;
		marked = curr->tower[level].mark;
		__asm__ __volatile__("":::"memory");
		succ = curr->tower[level].next_node_ptr;
	    }

	    if (key > curr->key) {
		pred = curr;
		curr = succ;
	    } else
		break;
	}

	preds[level] = pred;
	succs[level] = curr;
    }

    return (curr->key == key) ? true : false;
}

nodeの追加

各スレッドはadd()を呼び出す。add()内部では、pthread_keyで管理されている作業領域wsを取り出して、 実際の追加を行う_add()を呼ぶ。

bool_t add(skiplist_t * sl, const lkey_t key, const val_t val)
{
    workspace_t *ws = get_workspace(sl);
    assert(ws != NULL);
    return _add(sl, ws->preds, ws->succs, key, val);
}


/*
 * static bool_t _add(skiplist_t *sl, skiplist_node_t **preds,
 *                    skiplist_node_t **succs, const lkey_t key, const val_t val)
 *
 * Insert (key, val) using the caller's scratch arrays preds/succs (maxLevel
 * entries each). The insertion takes effect when the level-0 CAS succeeds.
 * The upper levels are then linked one by one, bottom to top.
 *
 * Returns true on insertion, false if key is already present or the node
 * cannot be allocated.
 *
 * NOTE(review): _delete() frees nodes immediately. A concurrent delete of the
 * node being linked here can therefore leave this loop touching freed memory.
 * A safe memory-reclamation scheme is needed to close that hole.
 */
static bool_t _add(skiplist_t * sl, skiplist_node_t ** preds,
	 skiplist_node_t ** succs, const lkey_t key, const val_t val)
{
    int level, topLevel;
    const int bottomLevel = 0;
    skiplist_node_t *pred, *succ, *newNode = NULL;

    /*
     * Geometric level distribution (p = 1/2): level k is reached with
     * probability 2^-k, which gives the expected O(log n) search. A uniform
     * rand() % maxLevel would put half of all nodes on the upper half of the
     * levels.
     * NOTE(review): POSIX does not require rand() to be thread-safe; a
     * per-thread generator would be preferable.
     */
    topLevel = 0;
    while (topLevel < sl->maxLevel - 1 && rand() > RAND_MAX / 2)
	topLevel++;
    assert(0 <= topLevel && topLevel < sl->maxLevel);

    while (1) {
	if (search(sl, key, preds, succs) == true) {
	    /* newNode (if any) was never published, so nobody else sees it. */
	    if (newNode != NULL)
		free_node(newNode);
	    return false;
	}

	/* Allocate once and reuse the node across retries; allocating on
	 * every iteration would leak one node per failed level-0 CAS. */
	if (newNode == NULL &&
	    (newNode = create_node(topLevel, key, val)) == NULL)
	    return false;

	/* newNode is not visible yet, so plain stores are safe here. */
	for (level = bottomLevel; level <= topLevel; level++) {
	    newNode->tower[level].next_node_ptr = succs[level];
	    newNode->tower[level].mark = UNMARKED;
	}

	pred = preds[bottomLevel];
	succ = succs[bottomLevel];
	if (cas64(&pred->tower[bottomLevel], make_ref(succ, UNMARKED),
		  make_ref(newNode, UNMARKED)) == false)
	    continue;

	for (level = bottomLevel + 1; level <= topLevel; level++) {
	    while (1) {
		pred = preds[level];
		succ = succs[level];

		/*
		 * After a re-search, succs[level] may differ from the value
		 * newNode->tower[level] was initialised with. Linking newNode
		 * with a stale forward pointer would skip, or resurrect,
		 * nodes at this level. Update it with CAS; failure means a
		 * concurrent delete already marked this level of newNode, so
		 * stop linking (the insertion itself already succeeded).
		 */
		if (newNode->tower[level].next_node_ptr != succ &&
		    cas64(&newNode->tower[level],
			  make_ref(newNode->tower[level].next_node_ptr, UNMARKED),
			  make_ref(succ, UNMARKED)) == false)
		    return true;

		if (cas64(&pred->tower[level], make_ref(succ, UNMARKED),
			  make_ref(newNode, UNMARKED)) == true)
		    break;

		/* pred/succ changed underneath us: recompute them. */
		search(sl, key, preds, succs);
	    }
	}
	return true;
    }
}

関数_add()内部で呼ぶnode生成関数create_node()を示す。

/* Size of a node followed by `level` trailing pointers. create_node() below
 * allocates the tower as a separate array and does not need this. */
#define node_size(level)	  (sizeof(skiplist_node_t) + ((level) * sizeof(skiplist_node_t *)))

/*
 * static skiplist_node_t *create_node(const int topLevel, const lkey_t key,
 *                                     const val_t val)
 *
 * Allocate a node holding key/val with a tower of topLevel + 1 links
 * (indices 0..topLevel). Every link starts as (NULL, UNMARKED).
 *
 * Returns the node, or NULL if topLevel < 0 or an allocation fails.
 */
static skiplist_node_t *create_node(const int topLevel, const lkey_t key, const val_t val)
{
    skiplist_node_t *node;
    int level;

    if (topLevel < 0) {
      elog("invalid topLevel");
      return NULL;
    }

    /* The tower lives in its own allocation, so the node needs no trailing
     * space. */
    if ((node = calloc(1, sizeof *node)) == NULL) {
      elog("calloc error");
      return NULL;
    }

    node->key = key;
    node->val = val;
    node->topLevel = topLevel;

    if ((node->tower = calloc((size_t) topLevel + 1, sizeof *node->tower)) == NULL) {
      elog("calloc error");
      free(node);
      return NULL;
    }

    /* calloc leaves mark == 0 == MARKED, so every entry, including the top
     * one at index topLevel, must be set explicitly. */
    for (level = 0; level <= topLevel; level++) {
	node->tower[level].next_node_ptr = NULL;
	node->tower[level].mark = UNMARKED;
    }

    return node;
}
nodeの削除
bool_t delete(skiplist_t * sl, lkey_t key, val_t * val)
{
    workspace_t *ws = get_workspace(sl);
    assert(ws != NULL);
    return _delete(sl, ws->preds, ws->succs, key, val);
}


/*
 * static bool_t _delete(skiplist_t *sl, skiplist_node_t **preds,
 *                       skiplist_node_t **succs, const lkey_t key, val_t *val)
 *
 * Remove key using the caller's scratch arrays. The victim's links are marked
 * from its top level down. Setting the level-0 mark decides the deletion:
 * only the thread whose CAS sets it reports success. That thread then:
 *   - re-runs search() to unlink the victim at every level;
 *   - stores its value in *val (when val is not NULL);
 *   - frees it.
 *
 * Returns true if this call deleted key, false if key was absent or another
 * thread deleted it first.
 *
 * NOTE(review): free_node(victim) runs while other threads may still hold
 * pointers to the victim, e.g. searches in progress or an _add() still linking
 * its upper levels. Without a safe memory-reclamation scheme (hazard pointers,
 * epochs, ...) this is a use-after-free hazard.
 */
static bool_t _delete(skiplist_t * sl, skiplist_node_t ** preds,
	    skiplist_node_t ** succs, const lkey_t key, val_t * val)
{
    int level;
    const int bottomLevel = 0;
    skiplist_node_t *succ, *victim;
    node_stat marked;
    bool_t iMarkedIt;

    if (search(sl, key, preds, succs) == false)
	return false;

    victim = succs[bottomLevel];

    /*
     * Mark the upper levels with CAS. Plain stores would race with an _add()
     * that is CASing a new node in right after victim at this level: the
     * new pointer could be overwritten with a stale succ, and the new node
     * would be lost at that level. On CAS failure, re-read and retry until
     * the level is marked (by us or by a competing deleter). The mark is read
     * before the pointer, behind a compiler barrier, so the pair is
     * consistent.
     */
    for (level = victim->topLevel; level >= bottomLevel + 1; level--) {
	marked = victim->tower[level].mark;
	__asm__ __volatile__("":::"memory");
	succ = victim->tower[level].next_node_ptr;

	while (marked == UNMARKED) {
	    cas64(&victim->tower[level], make_ref(succ, UNMARKED),
		  make_ref(succ, MARKED));
	    marked = victim->tower[level].mark;
	    __asm__ __volatile__("":::"memory");
	    succ = victim->tower[level].next_node_ptr;
	}
    }

    /* Level 0: whichever thread sets this mark owns the deletion. */
    succ = victim->tower[bottomLevel].next_node_ptr;

    while (1) {
	iMarkedIt = cas64(&victim->tower[bottomLevel],
			  make_ref(succ, UNMARKED), make_ref(succ, MARKED));

	marked = victim->tower[bottomLevel].mark;
	__asm__ __volatile__("":::"memory");
	succ = victim->tower[bottomLevel].next_node_ptr;

	if (iMarkedIt == true) {
	    /* search() snips every marked node on key's path, i.e. victim. */
	    search(sl, key, preds, succs);
	    if (val != NULL)
		*val = victim->val;
	    free_node(victim);
	    return true;
	} else if (marked == MARKED)
	    return false;	/* another thread won the race */
	/* otherwise only succ changed (an insert after victim): retry */
    }
}
nodeの検索
val_t find(skiplist_t * sl, lkey_t key)
{
    workspace_t *ws = get_workspace(sl);
    assert(ws != NULL);
    return _find(sl, ws->preds, ws->succs, key);
}

/*
 * static bool_t _find(skiplist_t *sl, skiplist_node_t **preds,
 *                     skiplist_node_t **succs, const lkey_t key)
 *
 * Wait-free membership test. It never modifies the list and never restarts:
 * marked (logically deleted) nodes are stepped over rather than unlinked.
 * preds and succs are unused; they are kept only for interface symmetry with
 * _add() and _delete().
 *
 * Returns true if the first unmarked node at level 0 with key >= key has
 * exactly key.
 */
static bool_t _find(skiplist_t * sl, skiplist_node_t ** preds,
	      skiplist_node_t ** succs, const lkey_t key)
{
    int level;
    const int bottomLevel = 0;
    skiplist_node_t *pred, *curr, *succ;
    node_stat marked;

    (void) preds;
    (void) succs;

    pred = sl->head;
    curr = NULL;

    /* The top usable level is maxLevel - 1: init_list() links head to tail
     * only on levels 0..maxLevel-1, and head->tower[maxLevel] stays NULL. */
    for (level = sl->maxLevel - 1; level >= bottomLevel; level--) {
	curr = pred->tower[level].next_node_ptr;

	while (1) {
	    /* Mark first, then pointer, so the pair is consistent. The empty
	     * asm is a compiler barrier; x86 does not reorder loads. */
	    marked = curr->tower[level].mark;
	    __asm__ __volatile__("":::"memory");
	    succ = curr->tower[level].next_node_ptr;

	    /* Step over marked nodes via their frozen successor. Re-reading
	     * pred's link instead would wait until some other thread unlinks
	     * curr, so the search would no longer be wait-free. */
	    while (marked == MARKED) {
		curr = succ;
		marked = curr->tower[level].mark;
		__asm__ __volatile__("":::"memory");
		succ = curr->tower[level].next_node_ptr;
	    }

	    if (curr->key < key) {
		pred = curr;
		curr = succ;
	    } else
		break;
	}
    }

    return (curr->key == key) ? true : false;
}

実行

ソースはGitHubに移行しました。



Last-modified: 2014-7-6