最近,我们班助教下发了新学期数据结构课程的大作业。像大作业这么喜闻乐见、愉悦身心的东西,当然要想方设法把它水完啦~于是便有了今天这样一篇文章。

题目要求就是实现STL中的3个容器:优先队列,更加优秀的双端队列,有序映射。
其实就是封装三个非常经典的数据结构:可并堆,块状数组,平衡树。
先从最为简单的优先队列开始说起:

这道题给了 5s 的时限、1024MB 的内存,测试点只有 5 个,平均每个点 1s。
估算数据范围不会很大,卡常问题不会严重。于是我选择最为简单的:左偏树。
左偏树无非就是向左偏的树(废话),在key满足二叉堆性质的同时,满足左节点到空节点的最小距离>=右节点到空节点的最小距离。
实现的话,完全基于合并,也没什么好说的。
合并无非就是判空决定是否停止、判key决定合并方向。
压入为让根和新节点合并,弹出则是将根替换成两个孩子的合并。
一开始在 maintain 中漏掉了"左孩子为空时交换左右孩子"这一判断,导致树被卡成了链表;补上之后轻松通过。
#ifndef SJTU_PRIORITY_QUEUE_HPP
#define SJTU_PRIORITY_QUEUE_HPP
#include <cstddef>
#include <functional>
#include <utility>
#include <vector>
#include "exceptions.hpp"
/*#include <iostream>
#define debug cout
using namespace std;*/
namespace sjtu {
/**
* a container like std::priority_queue which is a heap internal.
*/
template<typename T, class Compare = std::less<T> >
class priority_queue {
private:
	/// Node of a leftist heap; the element that is largest w.r.t. Compare is at the root.
	struct Node {
		T val;
		Node *ls, *rs;
		int dis; // null-path length: 0 when the right child is empty (see maintain()).
		Node(const T &_val): val(_val), ls(nullptr), rs(nullptr), dis(0) {}
		/// Re-establishes the leftist property after a child changed: the children are
		/// swapped if the left one is empty or has a smaller dis than the right one,
		/// then dis is recomputed from the right child.
		void maintain() {
			if(ls == nullptr || (rs != nullptr && ls->dis < rs->dis)) std::swap(ls, rs);
			dis = (rs == nullptr ? -1 : rs->dis) + 1;
		}
	}*root;
	/// Frees every node of the subtree rooted at pos.
	/// Iterative on purpose: the left spine of a leftist tree can be O(n) long (e.g.
	/// after pushing an increasing sequence), so recursion could overflow the stack.
	void deleteAll(Node* pos) {
		while(pos != nullptr) {
			if(pos->ls != nullptr) {
				// Rotate the left child above pos until pos has no left child left.
				Node* l = pos->ls;
				pos->ls = l->rs;
				l->rs = pos;
				pos = l;
			} else {
				Node* r = pos->rs;
				delete pos;
				pos = r;
			}
		}
	}
	/// Returns a deep copy of the subtree rooted at x (nullptr for an empty tree).
	/// Uses an explicit work list instead of recursion for the same reason as deleteAll.
	/// If copying an element throws, the partial copy is freed and the exception rethrown.
	Node* copy(const Node* x) {
		if(x == nullptr) return nullptr;
		Node* ret = new Node(x->val);
		ret->dis = x->dis; // the source already satisfies the leftist property
		try {
			// (source node, its copy) pairs whose children still have to be copied.
			std::vector<std::pair<const Node*, Node*> > todo;
			todo.push_back(std::make_pair(x, ret));
			while(!todo.empty()) {
				const Node* src = todo.back().first;
				Node* dst = todo.back().second;
				todo.pop_back();
				if(src->ls != nullptr) {
					dst->ls = new Node(src->ls->val);
					dst->ls->dis = src->ls->dis;
					todo.push_back(std::make_pair(static_cast<const Node*>(src->ls), dst->ls));
				}
				if(src->rs != nullptr) {
					dst->rs = new Node(src->rs->val);
					dst->rs->dis = src->rs->dis;
					todo.push_back(std::make_pair(static_cast<const Node*>(src->rs), dst->rs));
				}
			}
		} catch(...) {
			deleteAll(ret); // ret is always a well-formed (partial) tree
			throw;
		}
		return ret;
	}
	/// Merges two heaps (either may be empty) and returns the new root. The recursion
	/// only follows right children, which are short in a leftist heap.
	Node* merge(Node* a, Node* b) {
		if(a == nullptr || b == nullptr) return b == nullptr ? a : b;
		if(Compare()(a->val, b->val)) std::swap(a, b);
		a->rs = merge(a->rs, b);
		a->maintain();
		return a;
	}
	size_t _size;
public:
	priority_queue(): root(nullptr), _size(0) {}
	priority_queue(const priority_queue &other): root(copy(other.root)), _size(other._size) {}
	~priority_queue() {
		deleteAll(root);
	}
	/// Copy assignment; copies before releasing the old contents so that a throwing
	/// element copy leaves *this unchanged.
	priority_queue &operator=(const priority_queue &other) {
		if(this == &other) return *this;
		Node* fresh = copy(other.root);
		deleteAll(root);
		root = fresh;
		_size = other._size;
		return *this;
	}
	/// @throws container_is_empty if the queue is empty.
	const T & top() const {
		if(root == nullptr) throw container_is_empty();
		return root->val;
	}
	void push(const T &e) {
		Node* nv = new Node(e);
		root = merge(nv, root);
		++_size; // counted only once the node really exists
	}
	/// @throws container_is_empty if the queue is empty.
	void pop() {
		if(root == nullptr) throw container_is_empty();
		--_size;
		Node* mem = root;
		root = merge(root->ls, root->rs);
		delete mem;
	}
	size_t size() const {
		return _size;
	}
	bool empty() const {
		return root == nullptr;
	}
	/// Moves every element of other into *this; other is left empty.
	/// Merging a queue with itself is a no-op (it used to corrupt the heap).
	void merge(priority_queue &other) {
		if(this == &other) return;
		_size += other._size;
		root = merge(root, other.root);
		other.root = nullptr, other._size = 0;
	}
};
}
#endif
提交记录:

这玩意竟然还跑得挺快。夭寿啦~我萌点没啦~
下面来看看第二题:双端队列。
由于要求了迭代器的存在,这道题目所需的代码量要大出不少。再加上分块细节本来就多,导致此题愈发难写。
我选择的分块方式为链表套数组(而不是更加好写的链表套链表),随机访问无非就是先跳链表再数组寻址。元素过多则分裂、元素偏移中心过多则搬运元素。
迭代器方面,由于要求 O(1) 访问,所以不能只记录序列中的位置、每次都从头开始跳,而是要用指针保存所在的块以及块内指向的位置。+、- 运算则直接从当前位置向后或向前寻址。(毕竟你不能把 ++ 的复杂度变成根号的对吧)
于是我就这样写了,但却通不过编译:有些元素类型没有默认构造函数,不能直接用 new 开出数组。
那好办,再封一层指针就好了,用的时候再new嘛。
Test2的memory怎么Fail了?
题目要求(和 STL 一致)放进去的每个元素只能在被加入时构造一次、被删除时析构一次,而这个测试点专门检查这一点。
好吧,在搬运元素的时候直接指针移位,原位赋值nullptr。这回总归可以了。
最坑的是test3的鲁棒性测试,要求解决越界访问、失效迭代器的问题。
越界访问,加个判断即可。越界迭代器?由于该测试点考察的越界情况较为简单,只需判断迭代器记录的全局位置是否越界。
于是就有了这一堆奇奇怪怪的代码:
#ifndef SJTU_DEQUE_HPP
#define SJTU_DEQUE_HPP
#include "exceptions.hpp"
#include <cstddef>
// #define DEBUG
#ifdef DEBUG // debugging
#include <cassert>
#include <iostream>
#define debug cerr
using namespace std;
#endif
namespace sjtu {
// Number of element slots allocated for every block.
constexpr int BlockSiz = 1700;
#ifdef DEBUG
// Tiny thresholds so that splitting and re-centring are exercised by small tests.
constexpr int iniSt = 4;
constexpr int maxSiz = 10;
constexpr int maxSt = 6;
constexpr int minSt = 3;
#else
constexpr int iniSt = 301;   // slot of the first element in a new or re-centred block
constexpr int maxSiz = 1000; // a block is split once it holds maxSiz elements
constexpr int maxSt = 601;   // re-centre the block when its first slot drifts above this
constexpr int minSt = 101;   // ... or below this
#endif
template<class T>
class deque {
public:
class iterator;
class const_iterator;
private:
struct DataType {
T* dat;
DataType(): dat(nullptr) {}
DataType(const DataType &x) {dat = x.dat;}
explicit DataType(const T &x) {dat = new T(x);}
~DataType() {if(dat != nullptr) { delete dat; }}
DataType& operator = (DataType &&x) { if(this != &x) dat = x.dat, x.dat = nullptr; return *this; }
DataType& operator = (const std::nullptr_t &nptr) { return dat = nptr, *this; }
};
struct Block {
DataType* dat;
int st, ed; // visit st for the first element, ed for the last element.
Block *prv, *nxt;
Block():st(iniSt), ed(iniSt - 1), prv(nullptr), nxt(nullptr) { dat = new DataType[BlockSiz]; }
Block(const DataType* src, const int siz): st(iniSt), ed(iniSt + siz - 1), prv(nullptr), nxt(nullptr)
{ dat = new DataType[BlockSiz]; for(int i = 0; i < siz; i++) dat[st + i].dat = src[i].dat; } // start from src[0]
Block(const Block &b): st(b.st), ed(b.ed), prv(nullptr), nxt(nullptr)
{ dat = new DataType[BlockSiz]; for(int i = st; i <= ed; i++) dat[i] = DataType(*b.dat[i].dat); }
~Block() { delete[] dat; };
int size() { return ed - st + 1; }
friend void moveNext(Block* &cur, DataType* &tar, int n) {
int cat = tar - cur->dat;
while(cat + n > cur->ed) {
if(!cur->nxt->size()) break; // out of range.
n -= cur->ed - cat;
cur = cur->nxt;
tar = cur->dat + cur->st - 1;
cat = tar - cur->dat;
}
tar += n;
}
friend void movePrev(Block* &cur, DataType* &tar, int n) {
int cat = tar - cur->dat;
while(cat - n < cur->st) {
if(!cur->prv->size()) break; // out of range.
n -= cat - cur->st;
cur = cur->prv;
tar = cur->dat + cur->ed + 1;
cat = tar - cur->dat;
}
tar -= n;
}
friend void moveNext(const Block* &cur, const DataType* &tar, int n) {
int cat = tar - cur->dat;
while(cat + n > cur->ed) {
if(!cur->nxt->size()) break; // out of range.
n -= cur->ed - cat;
cur = cur->nxt;
tar = cur->dat + cur->st - 1;
cat = tar - cur->dat;
}
tar += n;
}
friend void movePrev(const Block* &cur, const DataType* &tar, int n) {
int cat = tar - cur->dat;
while(cat - n < cur->st) {
if(!cur->prv->size()) break; // out of range.
n -= cat - cur->st;
cur = cur->prv;
tar = cur->dat + cur->ed + 1;
cat = tar - cur->dat;
}
tar -= n;
}
void movEle() {
int nst = iniSt, ned = nst + size() - 1;
if(st > maxSt) for(int i = 0; i < size(); i++) dat[nst + i].dat = dat[st + i].dat, dat[st + i] = nullptr;
else for(int i = size() - 1; ~i; i--) dat[nst + i].dat = dat[st + i].dat, dat[st + i] = nullptr;
st = nst, ed = ned;
}
void trySplit() {
if(size() < maxSiz) {
if(st > maxSt || st < minSt) movEle();
return;
}
const int siz1 = size() / 2, siz2 = size() - siz1;
Block *n1 = new Block(dat + st, siz1), *n2 = new Block(dat + st + siz1, siz2);
n1->prv = prv, n1->nxt = n2, n2->prv = n1, n2->nxt = nxt;
#ifdef DEBUG
assert(prv != nullptr && nxt != nullptr);
#endif
prv->nxt = n1, nxt->prv = n2;
for(int i = st; i <= ed; i++) dat[i] = nullptr;
delete this;
}
void tryRemove() {
if(size()) return;
if(prv != nullptr && prv == nxt) {
#ifdef DEBUG
assert(prv->nxt == this && nxt->prv == this);
#endif
return; // don't delete last real block.
}
if(prv != nullptr) prv->nxt = nxt;
if(nxt != nullptr) nxt->prv = prv;
delete this;
}
void removeKth(int k) {
k += st - 1, --ed;
delete dat[k].dat;
for(int i = k; i <= ed; i++) dat[i].dat = dat[i + 1].dat;
dat[ed + 1] = nullptr;
tryRemove();
}
void insertKth(const T &v, int k) {
k += st - 1, ++ed;
for(int i = ed; i > k; i--) dat[i].dat = dat[i - 1].dat;
dat[k] = DataType(v), trySplit();
}
void push_front(const T &x) { dat[--st] = DataType(x), trySplit(); }
void push_back(const T &x) { dat[++ed] = DataType(x), trySplit(); }
void pop_front() { delete dat[st].dat; dat[st++] = nullptr, tryRemove(); }
void pop_back() { delete dat[ed].dat; dat[ed--] = nullptr, tryRemove(); }
}root; // root -> nxt is the head, root -> prv is the tail.
void deleteAll() {
auto p = root.nxt;
while(p != nullptr && p != &root) {
auto p2 = p->nxt;
delete p;
p = p2;
}
root.nxt = root.prv = nullptr;
}
void checkRoot() {
if(root.nxt == nullptr) {
#ifdef DEBUG
assert(root.prv == nullptr);
#endif
root.nxt = root.prv = new Block;
root.prv->nxt = &root, root.nxt->prv = &root;
}
}
void copyAll(const Block &root2) {
if(root2.nxt == nullptr) {
#ifdef DEBUG
assert(root2.prv == nullptr);
#endif
return;
}
root.nxt = new Block(*root2.nxt), root.nxt->prv = &root;
auto cur = root.nxt, cur2 = root2.nxt;
while(cur2->nxt != &root2) {
cur->nxt = new Block(*cur2->nxt), cur->nxt->prv = cur;
cur = cur->nxt, cur2 = cur2->nxt;
}
root.prv = cur, cur->nxt = &root;
}
T& accessKth(int n) {
++n;
auto p = root.nxt;
while(n > p->size()) n -= p->size(), p = p->nxt;
return *p->dat[p->st + n - 1].dat;
}
const T& accessKth(int n) const {
++n;
auto p = root.nxt;
while(n > p->size()) n -= p->size(), p = p->nxt;
return *p->dat[p->st + n - 1].dat;
}
int fullSiz;
iterator iteratorKth(int n) {
const int nn = n;
auto p = root.nxt;
while(n > p->size()) n -= p->size(), p = p->nxt;
return iterator(this, p, p->dat + p->st + n - 1, nn);
}
void insertKth(int n, const T &v) {
checkRoot();
auto p = root.nxt;
while(n > p->size() + 1) n -= p->size(), p = p->nxt;
p->insertKth(v, n);
}
void removeKth(int n) {
auto p = root.nxt;
while(n > p->size()) n -= p->size(), p = p->nxt;
p->removeKth(n);
}
bool checkAccessIterator(const iterator &it) const {
if(it.id > size() || it.id < 1) return 0;
return 1;
}
bool checkAccessIterator(const const_iterator &it) const {
if(it.id > size() || it.id < 1) return 0;
return 1;
}
public:
class iterator {
private:
Block* blk;
DataType* tar;
public:
deque* fa;
int id;
iterator() = default;
iterator(deque* _fa, Block* _blk, DataType* _tar, int _id): fa(_fa), blk(_blk), tar(_tar), id(_id) {}
iterator operator + (const int &n) const { auto ret = *this; ret.id += n, n >= 0 ? moveNext(ret.blk, ret.tar, n) : movePrev(ret.blk, ret.tar, -n); return ret; }
iterator operator - (const int &n) const { auto ret = *this; ret.id -= n, n >= 0 ? movePrev(ret.blk, ret.tar, n) : moveNext(ret.blk, ret.tar, -n); return ret; }
int operator - (const iterator &rhs) const { if(fa != rhs.fa) throw invalid_iterator(); else return id - rhs.id; }
iterator& operator += (const int &n) { return *this = *this + n; }
iterator& operator -= (const int &n) { return *this = *this - n; }
iterator operator ++ (int) { auto ret = *this; return *this = *this + 1, ret; }
iterator& operator ++ () { return *this = *this + 1; }
iterator operator -- (int) { auto ret = *this; return *this = *this - 1, ret; }
iterator& operator -- () { return *this = *this - 1; }
T& operator * () { if(!fa->checkAccessIterator(*this)) throw invalid_iterator(); else return *tar->dat; }
const T& operator * () const { if(!fa->checkAccessIterator(*this)) throw invalid_iterator(); return *tar->dat; }
T* operator -> () const { if(!fa->checkAccessIterator(*this)) throw invalid_iterator(); return tar->dat; }
bool operator == (const iterator &rhs) const { return fa == rhs.fa && blk == rhs.blk && tar == rhs.tar && id == rhs.id; }
bool operator == (const const_iterator &rhs) const { return fa == rhs.fa && blk == rhs.blk && tar == rhs.tar && id == rhs.id; }
bool operator != (const iterator &rhs) const { return !(*this == rhs); }
bool operator != (const const_iterator &rhs) const { return !(*this == rhs); }
};
class const_iterator {
private:
const Block* blk;
const DataType* tar;
public:
const deque* fa;
int id;
const_iterator(): fa(nullptr), blk(nullptr), tar(nullptr), id(-1) {}
const_iterator(const deque* _fa, const Block* _blk, const DataType* _tar, int _id): fa(_fa), blk(_blk), tar(_tar), id(_id) {}
const_iterator(const const_iterator &other): fa(other.fa), blk(other.blk), tar(other.tar), id(other.id) {}
const_iterator(const iterator &other): fa(other.fa), blk(other.blk), tar(other.tar), id(other.id) {}
const_iterator operator + (const int &n) const { auto ret = *this; ret.id += n, n >= 0 ? moveNext(ret.blk, ret.tar, n) : movePrev(ret.blk, ret.tar, -n); return ret; }
const_iterator operator - (const int &n) const { auto ret = *this; ret.id -= n, n >= 0 ? movePrev(ret.blk, ret.tar, n) : moveNext(ret.blk, ret.tar, -n); return ret; }
int operator - (const const_iterator &rhs) const { if(fa != rhs.fa) throw invalid_iterator(); else return id - rhs.id; }
const_iterator& operator += (const int &n) { return *this = *this + n; }
const_iterator& operator -= (const int &n) { return *this = *this - n; }
const_iterator operator ++ (int) { auto ret = *this; return *this = *this + 1, ret; }
const_iterator& operator ++ () { return *this = *this + 1; }
const_iterator operator -- (int) { auto ret = *this; return *this = *this - 1, ret; }
const_iterator& operator -- () { return *this = *this - 1; }
const T& operator * () const { if(!fa->checkAccessIterator(*this)) throw invalid_iterator(); else return *tar->dat; }
const T* operator -> () const noexcept { if(!fa->checkAccessIterator(*this)) throw invalid_iterator(); else return tar->dat; }
bool operator == (const iterator &rhs) const { return fa == rhs.fa && blk == rhs.blk && tar == rhs.tar && id == rhs.id; }
bool operator == (const const_iterator &rhs) const { return fa == rhs.fa && blk == rhs.blk && tar == rhs.tar && id == rhs.id; }
bool operator != (const iterator &rhs) const { return !(*this == rhs); }
bool operator != (const const_iterator &rhs) const { return !(*this == rhs); }
};
deque(): fullSiz(0) { checkRoot(); }
deque(const deque &other): fullSiz(other.fullSiz) { copyAll(other.root); }
~deque() { deleteAll(); }
deque &operator=(const deque &other) { if(&other != this) deleteAll(), copyAll(other.root), fullSiz = other.fullSiz; return *this; }
T & at(const size_t &pos) { if(pos >= size() || pos < 0) throw index_out_of_bound(); else return accessKth(pos); }
const T & at(const size_t &pos) const { if(pos >= size() || pos < 0) throw index_out_of_bound(); else return accessKth(pos); }
T & operator[] (const size_t &pos) { if(pos >= size() || pos < 0) throw index_out_of_bound(); else return accessKth(pos); }
const T & operator[] (const size_t &pos) const { if(pos >= size() || pos < 0) throw index_out_of_bound(); else return accessKth(pos); }
const T & front() const { if(empty()) throw container_is_empty(); else return *root.nxt->dat[root.nxt->st].dat; }
const T & back() const { if(empty()) throw container_is_empty(); else return *root.prv->dat[root.prv->ed].dat; }
iterator begin() { return iterator(this, root.nxt, root.nxt->dat + root.nxt->st, 1); }
const_iterator cbegin() const { return const_iterator(this, root.nxt, root.nxt->dat + root.nxt->st, 1); }
iterator end() { return iterator(this, root.prv, root.prv->dat + root.prv->ed + 1, size() + 1); }
const_iterator cend() const { return const_iterator(this, root.prv, root.prv->dat + root.prv->ed + 1, size() + 1); }
bool empty() const { return size() == 0; }
size_t size() const { return fullSiz; }
void clear() { deleteAll(), checkRoot(), fullSiz = 0; }
iterator insert(iterator pos, const T &value) {
if(pos.fa != this) throw invalid_iterator();
if(size_t(pos.id) > size() + 1) throw invalid_iterator();
insertKth(pos.id, value), ++fullSiz;
return iteratorKth(pos.id);
}
iterator erase(iterator pos) {
if(pos.fa != this) throw invalid_iterator();
if(size_t(pos.id) > size()) throw container_is_empty();
removeKth(pos.id), --fullSiz;
return size_t(pos.id) <= size() ? iteratorKth(pos.id) : end();
}
void push_back(const T &value) { checkRoot(), ++fullSiz, root.prv->push_back(value); }
void pop_back() { if(empty()) throw container_is_empty(); else --fullSiz, root.prv->pop_back(); }
void push_front(const T &value) { checkRoot(), ++fullSiz, root.nxt->push_front(value); }
void pop_front() { if(empty()) throw container_is_empty(); else --fullSiz, root.nxt->pop_front(); }
};
}
#endif
压行有什么错?压行那么可爱,为什么要迫害压行~

于是~这玩意就慢如爪巴了qwq
最后的(也是最喜欢的)映射——也就是我们最为熟悉的平衡树啦~
助教说去年大部分人写的是红黑树,不过我这种寒假没有提前写的人当然是不愿意写红黑树的(还不是因为懒)。于是我便选择了:替罪羊树。
平衡树的实现没什么好啰嗦的。
除了一开始删除有两个孩子的节点时,直接用前驱的值替换当前节点、再删除前驱节点,导致指向前驱的迭代器失效(后来改为在树中交换这两个节点的位置,再删除原节点)之外,没有任何问题。
当然,除了慢。
/**
* implement a container like std::map
*/
#ifndef SJTU_MAP_HPP
#define SJTU_MAP_HPP
// only for std::less<T>
#include <functional>
#include <cstddef>
#include "utility.hpp"
#include "exceptions.hpp"
namespace sjtu {
// Scapegoat balance factor: a subtree is rebuilt when one of a node's children holds
// more than ALPHA of that node's subtree size (see map::checkFail).
constexpr double ALPHA{0.73};
template<class Key, class T, class Compare = std::less<Key> >
class map {
public:
	class iterator;
	class const_iterator;
	typedef pair<const Key, T> value_type;
private:
	// Key comparisons. A null value_type* marks the end() sentinel node, which
	// compares greater than every real key.
	bool cmp(const value_type* a, const value_type* b) const {
		if(a == nullptr || b == nullptr) return b == nullptr; // nullptr greater than everything.
		return Compare()(a->first, b->first);
	}
	bool equal(const value_type* a, const value_type* b) const {
		if(a == nullptr || b == nullptr) return a == b;
		return !Compare()(a->first, b->first) && !Compare()(b->first, a->first);
	}
	bool cmp(const value_type* a, const Key* b) const {
		if(a == nullptr || b == nullptr) return b == nullptr; // nullptr greater than everything.
		return Compare()(a->first, *b);
	}
	bool equal(const value_type* a, const Key* b) const {
		if(a == nullptr || b == nullptr) return 0;
		return !Compare()(a->first, *b) && !Compare()(*b, a->first);
	}
	// Tree node owning its key/value pair; v == nullptr only for the end() sentinel.
	struct Node {
		value_type* v;
		Node *ls, *rs, *fa;
		int siz; // number of nodes in this subtree (sentinel included)
		Node(value_type* _v = nullptr): v(_v), ls(nullptr), rs(nullptr), fa(nullptr), siz(1) {}
		Node(const Node &oth): ls(nullptr), rs(nullptr), fa(nullptr), siz(oth.siz) { v = oth.v == nullptr ? nullptr : new value_type(*oth.v); }
		~Node() { delete v; }
		void maintain() { siz = (ls ? ls->siz : 0) + (rs ? rs->siz : 0) + 1; }
		void reset() { ls = rs = fa = nullptr, siz = 1; }
	}*root;
	// In-order flatten of subtree x into dst[1..cnt], detaching every node.
	void dfs(Node** const dst, int& cnt, Node* const x) {
		if(x->ls) dfs(dst, cnt, x->ls);
		dst[++cnt] = x;
		if(x->rs) dfs(dst, cnt, x->rs);
		x->reset();
	}
	// Build a perfectly balanced subtree from src[l..r] and return its root.
	Node* rebuild(Node** const src, const int l, const int r) {
		const int mid = (l + r) >> 1;
		Node* ret = src[mid];
		if(l < mid) ret->ls = rebuild(src, l, mid - 1), ret->ls->fa = ret;
		if(mid < r) ret->rs = rebuild(src, mid + 1, r), ret->rs->fa = ret;
		ret->maintain();
		return ret;
	}
	Node *fail, *fail_fa;
	// Find the highest ALPHA-unbalanced node on the path from x to the root.
	void checkFail(Node* x) {
		while(x) {
			if(x->ls && x->ls->siz > x->siz * ALPHA) fail_fa = x->fa, fail = x;
			if(x->rs && x->rs->siz > x->siz * ALPHA) fail_fa = x->fa, fail = x;
			x = x->fa;
		}
	}
	// Rebuild the highest unbalanced subtree above cur (if any). Nodes are reused,
	// so iterators stay valid.
	void checkRebuild(Node* const cur) {
		fail = fail_fa = nullptr, checkFail(cur);
		if(fail) {
			Node** temp = new Node* [fail->siz + 3];
			int cnt = 0;
			dfs(temp, cnt, fail);
			auto t = (fail_fa == nullptr ? root : (fail == fail_fa->ls ? fail_fa->ls : fail_fa->rs)) = rebuild(temp, 1, cnt);
			t->fa = fail_fa; // !
			delete[] temp;
		}
	}
	// Insert the pair v (taking ownership). If the key already exists, v is deleted and
	// the existing node is returned with `false`.
	pair<iterator, bool> insert(value_type* const v) {
		Node* cur = root;
		while(1) {
			if(equal(cur->v, v)) {
				delete v;
				return pair<iterator, bool>(iterator(this, cur), 0);
			}
			if(cmp(cur->v, v)) {
				if(cur->rs) cur = cur->rs;
				else {
					cur->rs = new Node(v), cur->rs->fa = cur;
					cur->maintain(), cur = cur->rs;
					break;
				}
			} else {
				if(cur->ls) cur = cur->ls;
				else {
					cur->ls = new Node(v), cur->ls->fa = cur;
					cur->maintain(), cur = cur->ls;
					break;
				}
			}
		}
		fixChain(cur), checkRebuild(cur);
		return pair<iterator, bool>(iterator(this, cur), 1);
	}
	// Recompute subtree sizes from pos up to the root.
	void fixChain(Node* pos) {
		while(pos) {
			pos->maintain();
			pos = pos->fa;
		}
	}
	// Unlink and delete node pos (never the sentinel). A node with two children first
	// swaps places *in the tree* with its in-order predecessor (instead of copying the
	// predecessor's value), so iterators to the predecessor stay valid.
	void erase(Node* pos) {
		if(pos->ls == nullptr && pos->rs == nullptr) {
			if(pos->fa) (pos == pos->fa->ls ? pos->fa->ls : pos->fa->rs) = nullptr;
			fixChain(pos->fa), checkRebuild(pos->fa);
			delete pos;
		} else if(pos->ls == nullptr || pos->rs == nullptr) {
			Node* son = pos->ls ? pos->ls : pos->rs;
			if(pos->fa) (pos == pos->fa->ls ? pos->fa->ls : pos->fa->rs) = son, son->fa = pos->fa;
			else root = son, son->fa = nullptr;
			fixChain(pos->fa), checkRebuild(pos->fa);
			delete pos;
		} else {
			Node *son = pos->ls;
			while(son->rs) son = son->rs;
			if(son->fa != pos) {
				if(pos->fa) (pos == pos->fa->ls ? pos->fa->ls : pos->fa->rs) = son;
				else root = son;
				(son == son->fa->ls ? son->fa->ls : son->fa->rs) = pos;
				std::swap(pos->ls, son->ls), std::swap(pos->rs, son->rs), std::swap(pos->fa, son->fa);
				if(pos->ls) pos->ls->fa = pos;
				if(pos->rs) pos->rs->fa = pos;
				if(son->ls) son->ls->fa = son;
				if(son->rs) son->rs->fa = son;
			} else {
				if(pos->fa) (pos == pos->fa->ls ? pos->fa->ls : pos->fa->rs) = son, son->fa = pos->fa;
				else root = son, son->fa = nullptr;
				const auto son_ls = son->ls, son_rs = son->rs;
				(son == pos->ls ? son->ls : son->rs) = pos, pos->fa = son;
				if((pos == son->ls ? (son->rs = pos->rs) : (son->ls = pos->ls))) (pos == son->ls ? son->rs : son->ls)->fa = son;
				if((pos->ls = son_ls)) pos->ls->fa = pos;
				if((pos->rs = son_rs)) pos->rs->fa = pos;
			}
			erase(pos); // pos now has at most one child
		}
	}
	// Node holding *tar, or nullptr.
	Node* find(const Key* tar) const {
		Node* cur = root;
		while(cur) {
			if(equal(cur->v, tar)) return cur;
			if(cmp(cur->v, tar)) cur = cur->rs;
			else cur = cur->ls;
		}
		return nullptr;
	}
	// In-order predecessor / successor; nullptr when there is none.
	Node* findPrv(Node* pos) {
		if(pos == nullptr) return nullptr;
		if(pos->ls) {
			Node* ret = pos->ls;
			while(ret->rs) ret = ret->rs;
			return ret;
		}
		while(pos->fa && pos == pos->fa->ls) pos = pos->fa;
		return pos->fa;
	}
	Node* findNxt(Node* pos) {
		if(pos == nullptr) return nullptr;
		if(pos->rs) {
			Node* ret = pos->rs;
			while(ret->ls) ret = ret->ls;
			return ret;
		}
		while(pos->fa && pos == pos->fa->rs) pos = pos->fa;
		return pos->fa;
	}
	Node* findPrv(const Node* pos) const { // const versions of the previous two functions.
		if(pos == nullptr) return nullptr;
		if(pos->ls) {
			Node* ret = pos->ls;
			while(ret->rs) ret = ret->rs;
			return ret;
		}
		while(pos->fa && pos == pos->fa->ls) pos = pos->fa;
		return pos->fa;
	}
	Node* findNxt(const Node* pos) const {
		if(pos == nullptr) return nullptr;
		if(pos->rs) {
			Node* ret = pos->rs;
			while(ret->ls) ret = ret->ls;
			return ret;
		}
		while(pos->fa && pos == pos->fa->rs) pos = pos->fa;
		return pos->fa;
	}
	void deleteAll(Node* pos) {
		if(pos == nullptr) return;
		if(pos->ls) deleteAll(pos->ls);
		if(pos->rs) deleteAll(pos->rs);
		delete pos;
	}
	Node* copyAll(Node* cur) {
		if(cur == nullptr) return nullptr;
		Node* ret = new Node(*cur);
		if(cur->ls) ret->ls = copyAll(cur->ls), ret->ls->fa = ret;
		if(cur->rs) ret->rs = copyAll(cur->rs), ret->rs->fa = ret;
		return ret;
	}
	// Leftmost node (first element, or the sentinel when empty).
	Node* nodeBegin() const {
		Node* cur = root;
		while(cur->ls) cur = cur->ls;
		return cur;
	}
	// Rightmost node, which is always the end() sentinel.
	Node* nodeEnd() const {
		Node* cur = root;
		while(cur->rs) cur = cur->rs;
		return cur;
	}
public:
	/// Bidirectional iterator. Moving past end() or before begin() throws
	/// invalid_iterator and leaves the iterator unchanged.
	class iterator {
	public:
		map* bel;
		Node* tar;
		iterator(map* _bel = nullptr, Node* _tar = nullptr): bel(_bel), tar(_tar) {}
		iterator(const iterator &other):bel(other.bel), tar(other.tar) {}
		iterator operator++(int) { iterator ret = *this; ++*this; return ret; }
		iterator & operator++() {
			Node* nxt = bel ? bel->findNxt(tar) : nullptr;
			if(nxt == nullptr) throw invalid_iterator();
			tar = nxt;
			return *this;
		}
		iterator operator--(int) { iterator ret = *this; --*this; return ret; }
		iterator & operator--() {
			Node* prv = bel ? bel->findPrv(tar) : nullptr;
			if(prv == nullptr) throw invalid_iterator();
			tar = prv;
			return *this;
		}
		value_type & operator*() const { return *tar->v; }
		bool operator==(const iterator &rhs) const { return bel == rhs.bel && tar == rhs.tar; }
		bool operator==(const const_iterator &rhs) const { return bel == rhs.bel && tar == rhs.tar; }
		bool operator!=(const iterator &rhs) const { return !(*this == rhs); }
		bool operator!=(const const_iterator &rhs) const { return !(*this == rhs); }
		value_type* operator->() const noexcept { return tar->v; }
	};
	/// Read-only counterpart of iterator; constructible from an iterator.
	class const_iterator {
	public:
		const map* bel;
		const Node* tar;
		const_iterator(const map* _bel = nullptr, const Node* _tar = nullptr): bel(_bel), tar(_tar) {}
		const_iterator(const const_iterator &other):bel(other.bel), tar(other.tar) {}
		const_iterator(const iterator &other):bel(other.bel), tar(other.tar) {}
		const_iterator operator++(int) { const_iterator ret = *this; ++*this; return ret; }
		const_iterator & operator++() {
			const Node* nxt = bel ? bel->findNxt(tar) : nullptr;
			if(nxt == nullptr) throw invalid_iterator();
			tar = nxt;
			return *this;
		}
		const_iterator operator--(int) { const_iterator ret = *this; --*this; return ret; }
		const_iterator & operator--() {
			const Node* prv = bel ? bel->findPrv(tar) : nullptr;
			if(prv == nullptr) throw invalid_iterator();
			tar = prv;
			return *this;
		}
		const value_type & operator*() const { return *tar->v; }
		bool operator==(const iterator &rhs) const { return bel == rhs.bel && tar == rhs.tar; }
		bool operator==(const const_iterator &rhs) const { return bel == rhs.bel && tar == rhs.tar; }
		bool operator!=(const iterator &rhs) const { return !(*this == rhs); }
		bool operator!=(const const_iterator &rhs) const { return !(*this == rhs); }
		const value_type* operator->() const noexcept { return tar->v; }
	};
	map() { root = new Node(); } // the tree always contains the end() sentinel
	map(const map &other) { root = copyAll(other.root); }
	/// Copy assignment; copies before freeing, so a throwing copy leaves *this intact.
	map & operator=(const map &other) {
		if(this != &other) {
			Node* fresh = copyAll(other.root);
			deleteAll(root);
			root = fresh;
		}
		return *this;
	}
	~map() { deleteAll(root); }
	/// @throws index_out_of_bound if key is absent.
	T & at(const Key &key) { Node* tar = find(&key); if(tar == nullptr) throw index_out_of_bound(); return tar->v->second; }
	const T & at(const Key &key) const { Node* tar = find(&key); if(tar == nullptr) throw index_out_of_bound(); return tar->v->second; }
	/// Inserts a value-initialized T if key is absent.
	T & operator[](const Key &key) {
		Node* tar = find(&key);
		if(tar == nullptr) {
			value_type* nv = new value_type(key, T());
			tar = insert(nv).first.tar;
		}
		return tar->v->second;
	}
	const T & operator[](const Key &key) const { return at(key); }
	iterator begin() { return iterator(this, nodeBegin()); }
	const_iterator cbegin() const { return const_iterator(this, nodeBegin()); }
	iterator end() { return iterator(this, nodeEnd()); }
	const_iterator cend() const { return const_iterator(this, nodeEnd()); }
	bool empty() const { return size() == 0; }
	size_t size() const { return root->siz - 1; } // minus the sentinel
	void clear() { deleteAll(root), root = new Node(); }
	pair<iterator, bool> insert(const value_type &value) { value_type* nv = new value_type(value); return insert(nv); }
	/// @throws invalid_iterator for a foreign, default-constructed or end() iterator.
	void erase(iterator pos) {
		if(pos.bel != this || pos.tar == nullptr || pos.tar->v == nullptr) throw invalid_iterator();
		erase(pos.tar);
	}
	size_t count(const Key &key) const { auto tar = find(&key); return tar != nullptr; }
	iterator find(const Key &key) { auto tar = find(&key); return tar == nullptr ? end() : iterator(this, tar); }
	const_iterator find(const Key &key) const { auto tar = find(&key); return tar == nullptr ? cend() : const_iterator(this, tar); }
};
}
#endif

替罪羊树喜闻乐见地 TLE 辣~虽然据说是最快的"不平衡树",但和拥有主动平衡能力的重量级选手们相比,还是有不小的差距。
至少,第7个测试点跑5s这一点,不可接受。
于是,我转向了:Size Balanced Tree。
只需要修改几个函数,替罪羊秒变SBT~
/**
* implement a container like std::map
*/
#ifndef SJTU_MAP_HPP
#define SJTU_MAP_HPP
// only for std::less<T>
#include <functional>
#include <cstddef>
#include "utility.hpp"
#include "exceptions.hpp"
namespace sjtu {
template<class Key, class T, class Compare = std::less<Key> >
class map {
public:
class iterator;
class const_iterator;
typedef pair<const Key, T> value_type;
private:
bool cmp(const value_type* a, const value_type* b) const {
if(a == nullptr || b == nullptr) return b == nullptr; // nullptr greater than everything.
return Compare()(a->first, b->first);
}
bool equal(const value_type* a, const value_type* b) const {
if(a == nullptr || b == nullptr) return a == b;
return !Compare()(a->first, b->first) && !Compare()(b->first, a->first);
}
bool cmp(const value_type* a, const Key* b) const {
if(a == nullptr || b == nullptr) return b == nullptr; // nullptr greater than everything.
return Compare()(a->first, *b);
}
bool equal(const value_type* a, const Key* b) const {
if(a == nullptr || b == nullptr) return 0;
return !Compare()(a->first, *b) && !Compare()(*b, a->first);
}
struct Node {
value_type* v;
Node *ls, *rs, *fa;
int siz;
Node(value_type* _v = nullptr): v(_v), ls(nullptr), rs(nullptr), fa(nullptr), siz(1) {}
Node(const Node &oth): ls(nullptr), rs(nullptr), fa(nullptr), siz(oth.siz) { v = oth.v == nullptr ? nullptr : new value_type(*oth.v); }
~Node() { delete v; }
void maintain() { siz = (ls ? ls->siz : 0) + (rs ? rs->siz : 0) + 1; }
void reset() { ls = rs = fa = nullptr, siz = 1; }
}*root;
void rotate(Node* pos) {
Node* const fa = pos->fa;
if(fa->fa) (fa == fa->fa->ls ? fa->fa->ls : fa->fa->rs) = pos, pos->fa = fa->fa;
else root = pos, pos->fa = nullptr;
if(pos == fa->ls) {
fa->ls = pos->rs;
if(fa->ls) fa->ls->fa = fa;
pos->rs = fa, fa->fa = pos;
} else {
fa->rs = pos->ls;
if(fa->rs) fa->rs->fa = fa;
pos->ls = fa, fa->fa = pos;
}
fa->maintain(), pos->maintain();
}
void maintain(Node* pos) {
if(pos->ls == nullptr || pos->rs == nullptr) {
if(pos->ls == nullptr && pos->rs == nullptr) return;
if(pos->ls != nullptr) {
if(pos->ls->ls) {
const auto l = pos->ls;
rotate(l);
maintain(pos);
maintain(l);
} else if(pos->ls->rs) {
const auto l = pos->ls, b = pos->ls->rs;
rotate(b);
rotate(b);
maintain(l);
maintain(pos);
maintain(b);
}
} else {
if(pos->rs->rs) {
const auto l = pos->rs;
rotate(l);
maintain(pos);
maintain(l);
} else if(pos->rs->ls) {
const auto l = pos->rs, b = pos->rs->ls;
rotate(b);
rotate(b);
maintain(l);
maintain(pos);
maintain(b);
}
}
return;
}
if(pos->ls->ls && pos->ls->ls->siz > pos->rs->siz) {
const auto l = pos->ls;
rotate(l);
maintain(pos);
maintain(l);
} else if(pos->ls->rs && pos->ls->rs->siz > pos->rs->siz) {
const auto l = pos->ls, b = pos->ls->rs;
rotate(b);
rotate(b);
maintain(l);
maintain(pos);
maintain(b);
} else if(pos->rs->rs && pos->rs->rs->siz > pos->ls->siz) {
const auto l = pos->rs;
rotate(l);
maintain(pos);
maintain(l);
} else if(pos->rs->ls && pos->rs->ls->siz > pos->ls->siz) {
const auto l = pos->rs, b = pos->rs->ls;
rotate(b);
rotate(b);
maintain(l);
maintain(pos);
maintain(b);
}
}
bool flag;
Node* insert(Node* const pos, value_type* const v) {
if(equal(pos->v, v)) {
delete v;
flag = 0;
return pos;
}
Node* ret;
if(cmp(pos->v, v)) {
if(pos->rs) ret = insert(pos->rs, v);
else {
ret = pos->rs = new Node(v);
pos->rs->fa = pos;
}
} else {
if(pos->ls) ret = insert(pos->ls, v);
else {
ret = pos->ls = new Node(v);
pos->ls->fa = pos;
}
}
pos->maintain();
maintain(pos);
return ret;
}
pair<iterator, bool> insert(value_type* const v) {
flag = 1;
return pair<iterator, bool>(iterator(this, insert(root, v)), flag);
}
void fixChain(Node* pos) {
while(pos) {
pos->maintain();
pos = pos->fa;
}
}
void erase(Node* pos) {
if(pos->ls == nullptr && pos->rs == nullptr) {
if(pos->fa) (pos == pos->fa->ls ? pos->fa->ls : pos->fa->rs) = nullptr;
fixChain(pos->fa);
delete pos;
} else {
if(pos->ls == nullptr || pos->rs == nullptr) {
Node* son = pos->ls ? pos->ls : pos->rs;
if(pos->fa) (pos == pos->fa->ls ? pos->fa->ls : pos->fa->rs) = son, son->fa = pos->fa;
else root = son, son->fa = nullptr;
fixChain(pos->fa);
delete pos;
} else {
Node *son = pos->ls;
while (son->rs) son = son->rs;
if(son->fa != pos) {
if(pos->fa) (pos == pos->fa->ls ? pos->fa->ls : pos->fa->rs) = son;
else root = son;
(son == son->fa->ls ? son->fa->ls : son->fa->rs) = pos;
std::swap(pos->ls, son->ls), std::swap(pos->rs, son->rs), std::swap(pos->fa, son->fa);
if(pos->ls) pos->ls->fa = pos; if(pos->rs) pos->rs->fa = pos;
if(son->ls) son->ls->fa = son; if(son->rs) son->rs->fa = son;
} else {
if(pos->fa) (pos == pos->fa->ls ? pos->fa->ls : pos->fa->rs) = son, son->fa = pos->fa;
else root = son, son->fa = nullptr;
const auto son_ls = son->ls, son_rs = son->rs;
(son == pos->ls ? son->ls : son->rs) = pos, pos->fa = son;
if((pos == son->ls ? (son->rs = pos->rs) : (son->ls = pos->ls))) (pos == son->ls ? son->rs : son->ls)->fa = son;
if((pos->ls = son_ls)) pos->ls->fa = pos;
if((pos->rs = son_rs)) pos->rs->fa = pos;
}
erase(pos);
}
}
}
Node* find(const Key* tar) const {
Node* cur = root;
while(cur) {
if(equal(cur->v, tar)) return cur;
if(cmp(cur->v, tar)) cur = cur->rs;
else cur = cur->ls;
}
return nullptr;
}
Node* findPrv(Node* pos) { // return nullptr when failed.
if(pos == nullptr) return nullptr;
if(pos->ls) {
Node* ret = pos->ls;
while(ret->rs) ret = ret -> rs;
return ret;
}
while(pos->fa && pos == pos->fa->ls) pos = pos->fa;
return pos->fa;
}
Node* findNxt(Node* pos) {
if(pos == nullptr) return nullptr;
if(pos->rs) {
Node* ret = pos->rs;
while(ret->ls) ret = ret -> ls;
return ret;
}
while(pos->fa && pos == pos->fa->rs) pos = pos->fa;
return pos->fa;
}
Node* findPrv(const Node* pos) const { // const version of previous two functions.
if(pos == nullptr) return nullptr;
if(pos->ls) {
Node* ret = pos->ls;
while(ret->rs) ret = ret -> rs;
return ret;
}
while(pos->fa && pos == pos->fa->ls) pos = pos->fa;
return pos->fa;
}
Node* findNxt(const Node* pos) const {
if(pos == nullptr) return nullptr;
if(pos->rs) {
Node* ret = pos->rs;
while(ret->ls) ret = ret -> ls;
return ret;
}
while(pos->fa && pos == pos->fa->rs) pos = pos->fa;
return pos->fa;
}
// Frees every node of the subtree rooted at pos, children before parent.
// Recursive, so the call depth equals the subtree height.
// NOTE(review): fine for a height-balanced tree; would overflow on a degenerate one.
void deleteAll(Node* pos) {
    if(pos != nullptr) {
        deleteAll(pos->ls);
        deleteAll(pos->rs);
        delete pos;
    }
}
// Deep-copies the subtree rooted at cur and returns the new subtree root
// (nullptr for an empty subtree). Each node is duplicated with Node's copy
// constructor; the child links and their parent pointers are rebuilt here.
Node* copyAll(Node* cur) {
    if(!cur) return nullptr;
    Node* const dup = new Node(*cur);
    if(Node* const l = cur->ls) {
        dup->ls = copyAll(l);
        dup->ls->fa = dup;
    }
    if(Node* const r = cur->rs) {
        dup->rs = copyAll(r);
        dup->rs->fa = dup;
    }
    return dup;
}
// Leftmost node of the tree, i.e. the first node in key order. Dereferences root,
// so root must be non-null.
Node* nodeBegin() const {
    Node* leftmost = root;
    for(; leftmost->ls; leftmost = leftmost->ls) {}
    return leftmost;
}
// Rightmost node of the tree; end() uses it as the past-the-end position
// (presumably the value-less sentinel node — confirm against Node's definition).
Node* nodeEnd() const {
    Node* rightmost = root;
    for(; rightmost->rs; rightmost = rightmost->rs) {}
    return rightmost;
}
public:
class iterator {
public:
map* bel;
Node* tar;
iterator(map* _bel = nullptr, Node* _tar = nullptr): bel(_bel), tar(_tar) {}
iterator(const iterator &other):bel(other.bel), tar(other.tar) {}
iterator operator++(int) { auto ret = *this; tar = bel->findNxt(tar); if(tar == nullptr) throw invalid_iterator(); else return ret; }
iterator & operator++() { tar = bel->findNxt(tar); if(tar == nullptr) throw invalid_iterator(); else return *this; }
iterator operator--(int) { auto ret = *this; tar = bel->findPrv(tar); if(tar == nullptr) throw invalid_iterator(); else return ret; }
iterator & operator--() { tar = bel->findPrv(tar); if(tar == nullptr) throw invalid_iterator(); else return *this; }
value_type & operator*() const { return *tar->v; }
bool operator==(const iterator &rhs) const { return bel == rhs.bel && tar == rhs.tar; }
bool operator==(const const_iterator &rhs) const { return bel == rhs.bel && tar == rhs.tar; }
bool operator!=(const iterator &rhs) const { return !(*this == rhs);}
bool operator!=(const const_iterator &rhs) const { return !(*this == rhs); }
value_type* operator->() const noexcept { return tar->v; }
};
class const_iterator {
public:
const map* bel;
const Node* tar;
const_iterator(const map* _bel = nullptr, const Node* _tar = nullptr): bel(_bel), tar(_tar) {}
const_iterator(const const_iterator &other):bel(other.bel), tar(other.tar) {}
const_iterator(const iterator &other):bel(other.bel), tar(other.tar) {}
const_iterator operator++(int) { auto ret = *this; tar = bel->findNxt(tar); if(tar == nullptr) throw invalid_iterator(); else return ret; }
const_iterator & operator++() { tar = bel->findNxt(tar); if(tar == nullptr) throw invalid_iterator(); else return *this; }
const_iterator operator--(int) { auto ret = *this; tar = bel->findPrv(tar); if(tar == nullptr) throw invalid_iterator(); else return ret; }
const_iterator & operator--() { tar = bel->findPrv(tar); if(tar == nullptr) throw invalid_iterator(); else return *this; }
const value_type & operator*() const { return *tar->v; }
bool operator==(const iterator &rhs) const { return bel == rhs.bel && tar == rhs.tar; }
bool operator==(const const_iterator &rhs) const { return bel == rhs.bel && tar == rhs.tar; }
bool operator!=(const iterator &rhs) const { return !(*this == rhs);}
bool operator!=(const const_iterator &rhs) const { return !(*this == rhs); }
const value_type* operator->() const noexcept { return tar->v; }
};
// A new map holds only the value-less node created here; size() subtracts it.
map() { root = new Node(); }
map(const map &other) { root = copyAll(other.root); }
// Copy first, release second: if copying throws, *this keeps its old tree
// instead of being left with a dangling root.
map & operator=(const map &other) {
    if(this != &other) {
        Node* const fresh = copyAll(other.root);
        deleteAll(root);
        root = fresh;
    }
    return *this;
}
~map() { deleteAll(root); }
// Checked element access: throws index_out_of_bound when key is absent.
T & at(const Key &key) {
    Node* const hit = find(&key);
    if(hit == nullptr) throw index_out_of_bound();
    return hit->v->second;
}
const T & at(const Key &key) const {
    Node* const hit = find(&key);
    if(hit == nullptr) throw index_out_of_bound();
    return hit->v->second;
}
// Returns the mapped value for key, inserting (key, T()) first if key is absent.
T & operator[](const Key &key) {
    Node* hit = find(&key);
    if(hit == nullptr) hit = insert(new value_type(key, T())).first.tar;
    return hit->v->second;
}
// On a const map this cannot insert, so it behaves like at().
const T & operator[](const Key &key) const { return at(key); }
iterator begin() { return iterator(this, nodeBegin()); }
const_iterator cbegin() const { return const_iterator(this, nodeBegin()); }
iterator end() { return iterator(this, nodeEnd()); }
const_iterator cend() const { return const_iterator(this, nodeEnd()); }
bool empty() const { return size() == 0; }
size_t size() const { return root->siz - 1; }
void clear() { deleteAll(root), root = new Node(); }
pair<iterator, bool> insert(const value_type &value) { value_type* nv = new value_type(value); return insert(nv); }
void erase(iterator pos) { if(pos.bel != this || pos.tar->v == nullptr) throw invalid_iterator(); else erase(pos.tar); }
size_t count(const Key &key) const { auto tar = find(&key); return tar != nullptr; }
iterator find(const Key &key) { auto tar = find(&key); return tar == nullptr ? end() : iterator(this, tar); }
const_iterator find(const Key &key) const { auto tar = find(&key); return tar == nullptr ? cend() : const_iterator(this, tar); }
};
}
#endif
不过看在第7个点只比替罪羊快0.5s的份上,估计这玩意的效率也不太行。

果然,SBT也光荣地TLE了。
我愚蠢地怀疑,评测姬迁移后,OJ变慢了。于是便去询问AT。


(为了保护助教的权益,没有截图头像框)
没错,自带大常数是萌点,绝不是黑点~想想你被卡常数过不去题百般调试却又无可奈何的样子,是不是就像RBQ一样?
(某RBQ:没错,是我了是我了)
从AT口中得知fstqwq学长去年写的也不是红黑树。查看其GitHub仓库后,发现他写的居然是——Splay!
惊了,Splay常数辣么大,怎么能过去这题?
于是复制其代码在本地进行测试,发现SBT用时4.5s的第7个测试点,Splay仅用时1.2s。
仔细分析发现,评测过半的运行时间被消耗在了第7个点上,而第7个点和第1个点只有数据范围不同,测试模式都是顺序访问。在顺序访问下,Splay会退化成一条单链,但每时每刻根节点和下一个要访问的节点之间都只有1的距离。所以最后的总复杂度是:线性~
岂不妙哉?我也改Splay~
#ifndef SJTU_MAP_HPP
#define SJTU_MAP_HPP
// only for std::less<T>
#include <functional>
#include <cstddef>
#include "utility.hpp"
#include "exceptions.hpp"
/* #include <iostream>
#define debug cerr
using std::cerr;
using std::endl; */
namespace sjtu {
template<class Key, class T, class Compare = std::less<Key> >
class map {
public:
class iterator;
class const_iterator;
typedef pair<const Key, T> value_type;
private:
bool cmp(const value_type* a, const value_type* b) const {
if(a == nullptr || b == nullptr) return b == nullptr; // nullptr greater than everything.
return Compare()(a->first, b->first);
}
bool equal(const value_type* a, const value_type* b) const {
if(a == nullptr || b == nullptr) return a == b;
return !Compare()(a->first, b->first) && !Compare()(b->first, a->first);
}
bool cmp(const value_type* a, const Key* b) const {
if(a == nullptr || b == nullptr) return b == nullptr; // nullptr greater than everything.
return Compare()(a->first, *b);
}
bool equal(const value_type* a, const Key* b) const {
if(a == nullptr || b == nullptr) return 0;
return !Compare()(a->first, *b) && !Compare()(*b, a->first);
}
struct Node {
value_type* v;
Node *ls, *rs, *fa;
int siz;
Node(value_type* _v = nullptr): v(_v), ls(nullptr), rs(nullptr), fa(nullptr), siz(1) {}
Node(const Node &oth): ls(nullptr), rs(nullptr), fa(nullptr), siz(oth.siz) { v = oth.v == nullptr ? nullptr : new value_type(*oth.v); }
~Node() { delete v; }
void maintain() { siz = (ls ? ls->siz : 0) + (rs ? rs->siz : 0) + 1; }
void reset() { ls = rs = fa = nullptr, siz = 1; }
}*root;
void fixChain(Node* pos) {
while(pos) {
pos->maintain();
pos = pos->fa;
}
}
void rotate(Node* pos) {
Node* const fa = pos->fa;
if(fa->fa) (fa == fa->fa->ls ? fa->fa->ls : fa->fa->rs) = pos, pos->fa = fa->fa;
else root = pos, pos->fa = nullptr;
if(pos == fa->ls) {
fa->ls = pos->rs;
if(fa->ls) fa->ls->fa = fa;
pos->rs = fa, fa->fa = pos;
} else {
fa->rs = pos->ls;
if(fa->rs) fa->rs->fa = fa;
pos->ls = fa, fa->fa = pos;
}
fa->maintain(), pos->maintain();
}
bool gid(Node* pos) {
return pos == pos->fa->ls;
}
void splay(Node* pos) {
if(pos == nullptr) return;
fixChain(pos);
while(pos != root) {
if(pos->fa->fa == nullptr) rotate(pos);
else if(gid(pos) == gid(pos->fa)) rotate(pos->fa), rotate(pos);
else rotate(pos), rotate(pos);
}
}
pair<iterator, bool> insert(value_type* const v) {
Node* cur = root;
while(1) {
if(equal(cur->v, v)) {
delete v;
splay(cur);
return pair<iterator, bool>(iterator(this, cur), 0); // todo: return an iterator.
}
if(cmp(cur->v, v)) {
if(cur->rs) cur = cur->rs;
else {
cur->rs = new Node(v), cur->rs->fa = cur;
cur->maintain(), cur = cur->rs;
break;
}
} else {
if(cur->ls) cur = cur->ls;
else {
cur->ls = new Node(v), cur->ls->fa = cur;
cur->maintain(), cur = cur->ls;
break;
}
}
}
splay(cur);
return pair<iterator, bool>(iterator(this, cur), 1);
}
void erase(Node* pos) {
if(pos->ls == nullptr && pos->rs == nullptr) {
if(pos->fa) (pos == pos->fa->ls ? pos->fa->ls : pos->fa->rs) = nullptr;
auto v = pos->fa;
delete pos; splay(v);
} else {
if(pos->ls == nullptr || pos->rs == nullptr) {
Node* son = pos->ls ? pos->ls : pos->rs;
if(pos->fa) (pos == pos->fa->ls ? pos->fa->ls : pos->fa->rs) = son, son->fa = pos->fa;
else root = son, son->fa = nullptr;
auto v = pos->fa;
delete pos; splay(v);
} else {
Node *son = pos->ls;
while (son->rs) son = son->rs;
if(son->fa != pos) {
if(pos->fa) (pos == pos->fa->ls ? pos->fa->ls : pos->fa->rs) = son;
else root = son;
(son == son->fa->ls ? son->fa->ls : son->fa->rs) = pos;
std::swap(pos->ls, son->ls), std::swap(pos->rs, son->rs), std::swap(pos->fa, son->fa);
if(pos->ls) pos->ls->fa = pos; if(pos->rs) pos->rs->fa = pos;
if(son->ls) son->ls->fa = son; if(son->rs) son->rs->fa = son;
} else {
if(pos->fa) (pos == pos->fa->ls ? pos->fa->ls : pos->fa->rs) = son, son->fa = pos->fa;
else root = son, son->fa = nullptr;
const auto son_ls = son->ls, son_rs = son->rs;
(son == pos->ls ? son->ls : son->rs) = pos, pos->fa = son;
if((pos == son->ls ? (son->rs = pos->rs) : (son->ls = pos->ls))) (pos == son->ls ? son->rs : son->ls)->fa = son;
if((pos->ls = son_ls)) pos->ls->fa = pos;
if((pos->rs = son_rs)) pos->rs->fa = pos;
}
erase(pos);
}
}
}
Node* find(const Key* tar) const {
Node* cur = root;
while(cur) {
if(equal(cur->v, tar)) return cur;
if(cmp(cur->v, tar)) cur = cur->rs;
else cur = cur->ls;
}
return nullptr;
}
Node* findPrv(Node* pos) { // return nullptr when failed.
if(pos == nullptr) return nullptr;
if(pos->ls) {
Node* ret = pos->ls;
while(ret->rs) ret = ret -> rs;
return ret;
}
while(pos->fa && pos == pos->fa->ls) pos = pos->fa;
return pos->fa;
}
Node* findNxt(Node* pos) {
if(pos == nullptr) return nullptr;
if(pos->rs) {
Node* ret = pos->rs;
while(ret->ls) ret = ret -> ls;
return ret;
}
while(pos->fa && pos == pos->fa->rs) pos = pos->fa;
return pos->fa;
}
Node* findPrv(const Node* pos) const { // const version of previous two functions.
if(pos == nullptr) return nullptr;
if(pos->ls) {
Node* ret = pos->ls;
while(ret->rs) ret = ret -> rs;
return ret;
}
while(pos->fa && pos == pos->fa->ls) pos = pos->fa;
return pos->fa;
}
Node* findNxt(const Node* pos) const {
if(pos == nullptr) return nullptr;
if(pos->rs) {
Node* ret = pos->rs;
while(ret->ls) ret = ret -> ls;
return ret;
}
while(pos->fa && pos == pos->fa->rs) pos = pos->fa;
return pos->fa;
}
void deleteAll(Node* pos) {
if(pos == nullptr) return;
if(pos->ls) deleteAll(pos->ls);
if(pos->rs) deleteAll(pos->rs);
delete pos;
}
Node* copyAll(Node* cur) {
if(cur == nullptr) return nullptr;
Node* ret = new Node(*cur);
if(cur->ls) ret->ls = copyAll(cur->ls), ret->ls->fa = ret;
if(cur->rs) ret->rs = copyAll(cur->rs), ret->rs->fa = ret;
return ret;
}
Node* nodeBegin() const {
Node* cur = root;
while(cur->ls) cur = cur->ls;
return cur;
}
Node* nodeEnd() const {
Node* cur = root;
while(cur->rs) cur = cur->rs;
return cur;
}
public:
class iterator {
public:
map* bel;
Node* tar;
iterator(map* _bel = nullptr, Node* _tar = nullptr): bel(_bel), tar(_tar) {}
iterator(const iterator &other):bel(other.bel), tar(other.tar) {}
iterator operator++(int) { auto ret = *this; tar = bel->findNxt(tar); if(tar == nullptr) throw invalid_iterator(); else return ret; }
iterator & operator++() { tar = bel->findNxt(tar); if(tar == nullptr) throw invalid_iterator(); else return *this; }
iterator operator--(int) { auto ret = *this; tar = bel->findPrv(tar); if(tar == nullptr) throw invalid_iterator(); else return ret; }
iterator & operator--() { tar = bel->findPrv(tar); if(tar == nullptr) throw invalid_iterator(); else return *this; }
value_type & operator*() const { return *tar->v; }
bool operator==(const iterator &rhs) const { return bel == rhs.bel && tar == rhs.tar; }
bool operator==(const const_iterator &rhs) const { return bel == rhs.bel && tar == rhs.tar; }
bool operator!=(const iterator &rhs) const { return !(*this == rhs);}
bool operator!=(const const_iterator &rhs) const { return !(*this == rhs); }
value_type* operator->() const noexcept { return tar->v; }
};
class const_iterator {
public:
const map* bel;
const Node* tar;
const_iterator(const map* _bel = nullptr, const Node* _tar = nullptr): bel(_bel), tar(_tar) {}
const_iterator(const const_iterator &other):bel(other.bel), tar(other.tar) {}
const_iterator(const iterator &other):bel(other.bel), tar(other.tar) {}
const_iterator operator++(int) { auto ret = *this; tar = bel->findNxt(tar); if(tar == nullptr) throw invalid_iterator(); else return ret; }
const_iterator & operator++() { tar = bel->findNxt(tar); if(tar == nullptr) throw invalid_iterator(); else return *this; }
const_iterator operator--(int) { auto ret = *this; tar = bel->findPrv(tar); if(tar == nullptr) throw invalid_iterator(); else return ret; }
const_iterator & operator--() { tar = bel->findPrv(tar); if(tar == nullptr) throw invalid_iterator(); else return *this; }
const value_type & operator*() const { return *tar->v; }
bool operator==(const iterator &rhs) const { return bel == rhs.bel && tar == rhs.tar; }
bool operator==(const const_iterator &rhs) const { return bel == rhs.bel && tar == rhs.tar; }
bool operator!=(const iterator &rhs) const { return !(*this == rhs);}
bool operator!=(const const_iterator &rhs) const { return !(*this == rhs); }
const value_type* operator->() const noexcept { return tar->v; }
};
map() { root = new Node(); }
map(const map &other) { root = copyAll(other.root); }
map & operator=(const map &other) { if(this !=&other) deleteAll(root), root = copyAll(other.root); return *this; }
~map() { deleteAll(root); }
T & at(const Key &key) { Node* tar = find(&key); if(tar == nullptr) throw index_out_of_bound(); return tar->v->second; }
const T & at(const Key &key) const { Node* tar = find(&key); if(tar == nullptr) throw index_out_of_bound(); return tar->v->second; }
T & operator[](const Key &key) {
Node* tar = find(&key);
if(tar == nullptr) {
value_type* nv = new value_type(key, T());
tar = insert(nv).first.tar;
}
return tar->v->second;
}
const T & operator[](const Key &key) const { return at(key); }
iterator begin() { return iterator(this, nodeBegin()); }
const_iterator cbegin() const { return const_iterator(this, nodeBegin()); }
iterator end() { return iterator(this, nodeEnd()); }
const_iterator cend() const { return const_iterator(this, nodeEnd()); }
bool empty() const { return size() == 0; }
size_t size() const { return root->siz - 1; }
void clear() { deleteAll(root), root = new Node(); }
pair<iterator, bool> insert(const value_type &value) { value_type* nv = new value_type(value); return insert(nv); }
void erase(iterator pos) { if(pos.bel != this || pos.tar->v == nullptr) throw invalid_iterator(); else erase(pos.tar); }
size_t count(const Key &key) const { auto tar = find(&key); return tar != nullptr; }
iterator find(const Key &key) { auto tar = find(&key); return tar == nullptr ? end() : iterator(this, tar); }
const_iterator find(const Key &key) const { auto tar = find(&key); return tar == nullptr ? cend() : const_iterator(this, tar); }
};
}
#endif
虽然本地(macOS)测第7个点时栈溢出RE,但是在远程Linux机器上执行ulimit -s unlimited后测试通过。
交上去:

Memory Test TLE,这不是我的问题,而是OJ的。
于是在本地用valgrind测试,发现如果不指定无限栈空间,就会RE。
算了,不管了,先让助教修OJ。
修好之后:

不出我所料,还真就RE了。
不就是栈溢出嘛~大不了我人工栈!
#ifndef SJTU_MAP_HPP
#define SJTU_MAP_HPP
#include <functional>
#include <cstddef>
#include "utility.hpp"
#include "exceptions.hpp"
namespace sjtu {
template<class Key, class T, class Compare = std::less<Key> >
class map {
public:
class iterator;
class const_iterator;
typedef pair<const Key, T> value_type;
private:
bool cmp(const value_type* a, const value_type* b) const {
if(a == nullptr || b == nullptr) return b == nullptr; // nullptr greater than everything.
return Compare()(a->first, b->first);
}
bool equal(const value_type* a, const value_type* b) const {
if(a == nullptr || b == nullptr) return a == b;
return !Compare()(a->first, b->first) && !Compare()(b->first, a->first);
}
bool cmp(const value_type* a, const Key* b) const {
if(a == nullptr || b == nullptr) return b == nullptr; // nullptr greater than everything.
return Compare()(a->first, *b);
}
bool equal(const value_type* a, const Key* b) const {
if(a == nullptr || b == nullptr) return 0;
return !Compare()(a->first, *b) && !Compare()(*b, a->first);
}
struct Node {
value_type* v;
Node *ls, *rs, *fa;
int siz;
Node(value_type* _v = nullptr): v(_v), ls(nullptr), rs(nullptr), fa(nullptr), siz(1) {}
Node(const Node &oth): ls(nullptr), rs(nullptr), fa(nullptr), siz(oth.siz) { v = oth.v == nullptr ? nullptr : new value_type(*oth.v); }
~Node() { delete v; }
void maintain() { siz = (ls ? ls->siz : 0) + (rs ? rs->siz : 0) + 1; }
void reset() { ls = rs = fa = nullptr, siz = 1; }
}*root;
void fixChain(Node* pos) {
while(pos) {
pos->maintain();
pos = pos->fa;
}
}
void rotate(Node* pos) {
Node* const fa = pos->fa;
if(fa->fa) (fa == fa->fa->ls ? fa->fa->ls : fa->fa->rs) = pos, pos->fa = fa->fa;
else root = pos, pos->fa = nullptr;
if(pos == fa->ls) {
fa->ls = pos->rs;
if(fa->ls) fa->ls->fa = fa;
pos->rs = fa, fa->fa = pos;
} else {
fa->rs = pos->ls;
if(fa->rs) fa->rs->fa = fa;
pos->ls = fa, fa->fa = pos;
}
fa->maintain(), pos->maintain();
}
bool gid(Node* pos) {
return pos == pos->fa->ls;
}
void splay(Node* pos) {
if(pos == nullptr) return;
fixChain(pos);
while(pos != root) {
if(pos->fa->fa == nullptr) rotate(pos);
else if(gid(pos) == gid(pos->fa)) rotate(pos->fa), rotate(pos);
else rotate(pos), rotate(pos);
}
}
pair<iterator, bool> insert(value_type* const v) {
Node* cur = root;
while(1) {
if(equal(cur->v, v)) {
delete v;
splay(cur);
return pair<iterator, bool>(iterator(this, cur), 0); // todo: return an iterator.
}
if(cmp(cur->v, v)) {
if(cur->rs) cur = cur->rs;
else {
cur->rs = new Node(v), cur->rs->fa = cur;
cur->maintain(), cur = cur->rs;
break;
}
} else {
if(cur->ls) cur = cur->ls;
else {
cur->ls = new Node(v), cur->ls->fa = cur;
cur->maintain(), cur = cur->ls;
break;
}
}
}
splay(cur);
return pair<iterator, bool>(iterator(this, cur), 1);
}
void erase(Node* pos) {
if(pos->ls == nullptr && pos->rs == nullptr) {
if(pos->fa) (pos == pos->fa->ls ? pos->fa->ls : pos->fa->rs) = nullptr;
auto v = pos->fa;
delete pos; splay(v);
} else {
if(pos->ls == nullptr || pos->rs == nullptr) {
Node* son = pos->ls ? pos->ls : pos->rs;
if(pos->fa) (pos == pos->fa->ls ? pos->fa->ls : pos->fa->rs) = son, son->fa = pos->fa;
else root = son, son->fa = nullptr;
auto v = pos->fa;
delete pos; splay(v);
} else {
Node *son = pos->ls;
while (son->rs) son = son->rs;
if(son->fa != pos) {
if(pos->fa) (pos == pos->fa->ls ? pos->fa->ls : pos->fa->rs) = son;
else root = son;
(son == son->fa->ls ? son->fa->ls : son->fa->rs) = pos;
std::swap(pos->ls, son->ls), std::swap(pos->rs, son->rs), std::swap(pos->fa, son->fa);
if(pos->ls) pos->ls->fa = pos; if(pos->rs) pos->rs->fa = pos;
if(son->ls) son->ls->fa = son; if(son->rs) son->rs->fa = son;
} else {
if(pos->fa) (pos == pos->fa->ls ? pos->fa->ls : pos->fa->rs) = son, son->fa = pos->fa;
else root = son, son->fa = nullptr;
const auto son_ls = son->ls, son_rs = son->rs;
(son == pos->ls ? son->ls : son->rs) = pos, pos->fa = son;
if((pos == son->ls ? (son->rs = pos->rs) : (son->ls = pos->ls))) (pos == son->ls ? son->rs : son->ls)->fa = son;
if((pos->ls = son_ls)) pos->ls->fa = pos;
if((pos->rs = son_rs)) pos->rs->fa = pos;
}
erase(pos);
}
}
}
Node* find(const Key* tar) const {
Node* cur = root;
while(cur) {
if(equal(cur->v, tar)) return cur;
if(cmp(cur->v, tar)) cur = cur->rs;
else cur = cur->ls;
}
return nullptr;
}
Node* findPrv(Node* pos) { // return nullptr when failed.
if(pos == nullptr) return nullptr;
if(pos->ls) {
Node* ret = pos->ls;
while(ret->rs) ret = ret -> rs;
return ret;
}
while(pos->fa && pos == pos->fa->ls) pos = pos->fa;
return pos->fa;
}
Node* findNxt(Node* pos) {
if(pos == nullptr) return nullptr;
if(pos->rs) {
Node* ret = pos->rs;
while(ret->ls) ret = ret -> ls;
return ret;
}
while(pos->fa && pos == pos->fa->rs) pos = pos->fa;
return pos->fa;
}
Node* findPrv(const Node* pos) const { // const version of previous two functions.
if(pos == nullptr) return nullptr;
if(pos->ls) {
Node* ret = pos->ls;
while(ret->rs) ret = ret -> rs;
return ret;
}
while(pos->fa && pos == pos->fa->ls) pos = pos->fa;
return pos->fa;
}
Node* findNxt(const Node* pos) const {
if(pos == nullptr) return nullptr;
if(pos->rs) {
Node* ret = pos->rs;
while(ret->ls) ret = ret -> ls;
return ret;
}
while(pos->fa && pos == pos->fa->rs) pos = pos->fa;
return pos->fa;
}
void deleteAll(Node* _pos) {
if(_pos == nullptr) return;
const int fs = _pos->siz;
Node** stk = new Node*[fs + 2];
int* step = new int[fs + 2], top = 0;
stk[++top] = _pos, step[top] = 0;
while(top) {
Node* const pos = stk[top];
const int ss = step[top--];
if(ss == 0) {
stk[++top] = pos, step[top] = 1;
if(pos->ls) stk[++top] = pos->ls, step[top] = 0;
} else if(ss == 1) {
stk[++top] = pos, step[top] = 2;
if(pos->rs) stk[++top] = pos->rs, step[top] = 0;
} else delete pos;
}
delete[] stk;
delete[] step;
}
Node* copyAll(Node* cur) {
if(cur == nullptr) return nullptr;
Node* ret = new Node(*cur);
if(cur->ls) ret->ls = copyAll(cur->ls), ret->ls->fa = ret;
if(cur->rs) ret->rs = copyAll(cur->rs), ret->rs->fa = ret;
return ret;
}
Node* nodeBegin() const {
Node* cur = root;
while(cur->ls) cur = cur->ls;
return cur;
}
Node* nodeEnd() const {
Node* cur = root;
while(cur->rs) cur = cur->rs;
return cur;
}
public:
class iterator {
public:
map* bel;
Node* tar;
iterator(map* _bel = nullptr, Node* _tar = nullptr): bel(_bel), tar(_tar) {}
iterator(const iterator &other):bel(other.bel), tar(other.tar) {}
iterator operator++(int) { auto ret = *this; tar = bel->findNxt(tar); if(tar == nullptr) throw invalid_iterator(); else return ret; }
iterator & operator++() { tar = bel->findNxt(tar); if(tar == nullptr) throw invalid_iterator(); else return *this; }
iterator operator--(int) { auto ret = *this; tar = bel->findPrv(tar); if(tar == nullptr) throw invalid_iterator(); else return ret; }
iterator & operator--() { tar = bel->findPrv(tar); if(tar == nullptr) throw invalid_iterator(); else return *this; }
value_type & operator*() const { return *tar->v; }
bool operator==(const iterator &rhs) const { return bel == rhs.bel && tar == rhs.tar; }
bool operator==(const const_iterator &rhs) const { return bel == rhs.bel && tar == rhs.tar; }
bool operator!=(const iterator &rhs) const { return !(*this == rhs);}
bool operator!=(const const_iterator &rhs) const { return !(*this == rhs); }
value_type* operator->() const noexcept { return tar->v; }
};
class const_iterator {
public:
const map* bel;
const Node* tar;
const_iterator(const map* _bel = nullptr, const Node* _tar = nullptr): bel(_bel), tar(_tar) {}
const_iterator(const const_iterator &other):bel(other.bel), tar(other.tar) {}
const_iterator(const iterator &other):bel(other.bel), tar(other.tar) {}
const_iterator operator++(int) { auto ret = *this; tar = bel->findNxt(tar); if(tar == nullptr) throw invalid_iterator(); else return ret; }
const_iterator & operator++() { tar = bel->findNxt(tar); if(tar == nullptr) throw invalid_iterator(); else return *this; }
const_iterator operator--(int) { auto ret = *this; tar = bel->findPrv(tar); if(tar == nullptr) throw invalid_iterator(); else return ret; }
const_iterator & operator--() { tar = bel->findPrv(tar); if(tar == nullptr) throw invalid_iterator(); else return *this; }
const value_type & operator*() const { return *tar->v; }
bool operator==(const iterator &rhs) const { return bel == rhs.bel && tar == rhs.tar; }
bool operator==(const const_iterator &rhs) const { return bel == rhs.bel && tar == rhs.tar; }
bool operator!=(const iterator &rhs) const { return !(*this == rhs);}
bool operator!=(const const_iterator &rhs) const { return !(*this == rhs); }
const value_type* operator->() const noexcept { return tar->v; }
};
map() { root = new Node(); }
map(const map &other) { root = copyAll(other.root); }
map & operator=(const map &other) { if(this !=&other) deleteAll(root), root = copyAll(other.root); return *this; }
~map() { deleteAll(root); }
T & at(const Key &key) { Node* tar = find(&key); if(tar == nullptr) throw index_out_of_bound(); return tar->v->second; }
const T & at(const Key &key) const { Node* tar = find(&key); if(tar == nullptr) throw index_out_of_bound(); return tar->v->second; }
T & operator[](const Key &key) {
Node* tar = find(&key);
if(tar == nullptr) {
value_type* nv = new value_type(key, T());
tar = insert(nv).first.tar;
}
return tar->v->second;
}
const T & operator[](const Key &key) const { return at(key); }
iterator begin() { return iterator(this, nodeBegin()); }
const_iterator cbegin() const { return const_iterator(this, nodeBegin()); }
iterator end() { return iterator(this, nodeEnd()); }
const_iterator cend() const { return const_iterator(this, nodeEnd()); }
bool empty() const { return size() == 0; }
size_t size() const { return root->siz - 1; }
void clear() { deleteAll(root), root = new Node(); }
pair<iterator, bool> insert(const value_type &value) { value_type* nv = new value_type(value); return insert(nv); }
void erase(iterator pos) { if(pos.bel != this || pos.tar->v == nullptr) throw invalid_iterator(); else erase(pos.tar); }
size_t count(const Key &key) const { auto tar = find(&key); return tar != nullptr; }
iterator find(const Key &key) { auto tar = find(&key); return tar == nullptr ? end() : iterator(this, tar); }
const_iterator find(const Key &key) const { auto tar = find(&key); return tar == nullptr ? cend() : const_iterator(this, tar); }
};
}
#endif
改完了,送~

什么?人工栈A了。
于是,这题就被我用各种奇奇怪怪的方法卡了过去。
(绝对不能让AT看到这篇博文,否则RBQ这个称呼怕是得成真~)
好的,就酱,2020年春季学期数据结构第一个大作业宣告完成。DDL为第9周星期四,现在是第3周星期三,提前了6周多的时间。
代码开源地址:https://github.com/cmd2001/STLite-2020
最后,火车票大作业求组队啊qwq
求各位大佬们带带我qwq
保证认真担当开发组吉祥物,摸鱼划水整活卖萌样样精通~
我是Amagi_Yukisaki,撕裂时空之人。如果想破碎世界连续性的话,请最好离我近一些。以上。




评论~ NOTHING