恋恋风辰的个人博客


  • Home

  • Archives

  • Categories

  • Tags

  • Search

零基础C++(33) 单例模式演变

Posted on 2025-03-01 | In 零基础C++

单例模式概念

今天给大家讲讲单例模式演变流程,从C++98到C++11经历了哪些变化?哪一种单例模式更为安全。

单例模式(Singleton Pattern)是一种创建型设计模式,旨在确保一个类在整个应用程序生命周期中仅且只有一个实例,并提供一个全局访问点以获取该实例。设计单例模式的主要原因和作用包括以下几个方面:

1. 控制实例数量

单例模式确保一个类只有一个实例,防止在程序中创建多个实例可能导致的资源浪费或状态不一致问题。例如,数据库连接池、配置管理类等通常使用单例模式,以确保全局只有一个连接池或配置实例。

2. 提供全局访问点

单例模式通过提供一个全局访问点,使得在程序的任何地方都可以方便地访问该实例。这对于那些需要在多个模块或组件之间共享的资源或服务尤为重要,如日志记录器、缓存管理器等。

3. 延迟实例化

单例模式通常采用懒加载的方式,即在第一次需要使用实例时才创建。这有助于节省系统资源,特别是在实例创建成本较高或初期并不需要该实例的情况下。

4. 避免命名冲突

通过将单例实例作为一个类的静态成员,可以避免在全局命名空间中引入多个实例,减少命名冲突的风险。

5. 管理共享资源

在多线程环境下,单例模式可以有效管理共享资源,确保线程安全。例如,操作系统中的线程池、任务管理器等常使用单例模式,以协调多个线程对资源的访问。

设计单例模式的考虑因素

虽然单例模式有诸多优点,但在设计和使用时也需要注意以下几点:

  • 线程安全:在多线程环境下,需要确保单例实例的创建和访问是线程安全的,常用的方法有双重检查锁定(Double-Checked Locking)和使用静态内部类等。
  • 延迟初始化:根据需求选择是否采用延迟初始化,以平衡性能和资源利用。
  • 可测试性:单例模式可能会影响代码的可测试性,特别是在单元测试中,可能需要通过依赖注入等手段来替代单例实例。
  • 限制扩展:单例模式通过限制实例数量可能会限制类的扩展性,需谨慎使用。

适用场景

  • 需要确保全局只有一个实例的场景,如配置管理、日志系统、设备驱动等。
  • 需要全局访问点来协调系统中的多个部分,如缓存、线程池等。

不适用场景

  • 需要多个实例以满足不同需求的场景。
  • 对象的生命周期需要更灵活控制的场合。

总的来说,单例模式通过控制类的实例数量和提供全局访问点,为系统资源管理和状态一致性提供了有效的解决方案。然而,在实际应用中,应根据具体需求和上下文环境,谨慎决定是否使用单例模式,以避免潜在的设计问题。

局部静态变量方式

这种方式最简单

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
//通过静态成员变量实现单例
//懒汉式
class Single2
{
private:
Single2()
{
}
Single2(const Single2 &) = delete;
Single2 &operator=(const Single2 &) = delete;
public:
static Single2 &GetInst()
{
static Single2 single;
return single;
}
};

上述代码通过局部静态成员single实现单例类,原理就是函数的局部静态变量生命周期随着进程结束而结束。上述代码通过懒汉式的方式实现。
调用如下

1
2
3
4
5
6
void test_single2()
{
//多线程情况下可能存在问题
cout << "s1 addr is " << &Single2::GetInst() << endl;
cout << "s2 addr is " << &Single2::GetInst() << endl;
}

程序输出如下

1
2
sp1  is  0x1304b10
sp2 is 0x1304b10

确实生成了唯一实例,在C++98年代,上述单例模式存在隐患,对于多线程方式生成的实例可能时多个。

随着C++ 11的来临,这种方式不再存在线程安全问题,是最为简单也是最适合新手的方式。

静态成员变量指针方式(饿汉式)

可以定义一个类的静态成员变量,用来控制实现单例,这种方式依靠静态成员提前初始化保证生成的单例是唯一的。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
//饿汉式
class Single2Hungry
{
private:
Single2Hungry()
{
}
Single2Hungry(const Single2Hungry &) = delete;
Single2Hungry &operator=(const Single2Hungry &) = delete;
public:
static Single2Hungry *GetInst()
{
if (single == nullptr)
{
single = new Single2Hungry();
}
return single;
}
private:
static Single2Hungry *single;
};

这么做的一个好处是我们可以通过饿汉式的方式避免线程安全问题

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
//饿汉式初始化
Single2Hungry *Single2Hungry::single = Single2Hungry::GetInst();
void thread_func_s2(int i)
{
cout << "this is thread " << i << endl;
cout << "inst is " << Single2Hungry::GetInst() << endl;
}
void test_single2hungry()
{
cout << "s1 addr is " << Single2Hungry::GetInst() << endl;
cout << "s2 addr is " << Single2Hungry::GetInst() << endl;
for (int i = 0; i < 3; i++)
{
thread tid(thread_func_s2, i);
tid.join();
}
}
int main(){
test_single2hungry()
}

程序输出如下

1
2
3
4
5
6
7
8
s1 addr is 0x1e4b00
s2 addr is 0x1e4b00
this is thread 0
inst is 0x1e4b00
this is thread 1
inst is 0x1e4b00
this is thread 2
inst is 0x1e4b00

可见无论单线程还是多线程模式下,通过静态成员变量的指针实现的单例类都是唯一的。

饿汉式是在程序启动时就进行单例的初始化,这种方式也可以通过懒汉式调用,无论饿汉式还是懒汉式都存在一个问题,就是什么时候释放内存?

多线程情况下,释放内存就很难了,还有二次释放内存的风险。

静态成员变量指针方式(懒汉式)

我们定义一个单例类并用懒汉式方式调用

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
//懒汉式指针
//即使创建指针类型也存在问题
class SinglePointer
{
private:
SinglePointer()
{
}
SinglePointer(const SinglePointer &) = delete;
SinglePointer &operator=(const SinglePointer &) = delete;
public:
static SinglePointer *GetInst()
{
if (single != nullptr)
{
return single;
}
s_mutex.lock();
if (single != nullptr)
{
s_mutex.unlock();
return single;
}
single = new SinglePointer();
s_mutex.unlock();
return single;
}
private:
static SinglePointer *single;
static mutex s_mutex;
};

在cpp文件里初始化静态成员,并定义一个测试函数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
//懒汉式
//在类的cpp文件定义static变量
SinglePointer *SinglePointer::single = nullptr;
std::mutex SinglePointer::s_mutex;
void thread_func_lazy(int i)
{
cout << "this is lazy thread " << i << endl;
cout << "inst is " << SinglePointer::GetInst() << endl;
}
void test_singlelazy()
{
for (int i = 0; i < 3; i++)
{
thread tid(thread_func_lazy, i);
tid.join();
}
//何时释放new的对象?造成内存泄漏
}
int main(){
test_singlelazy();
}

函数输出如下

1
2
3
4
5
6
this is lazy thread 0
inst is 0xbc1700
this is lazy thread 1
inst is 0xbc1700
this is lazy thread 2
inst is 0xbc1700

此时生成的单例对象的内存空间还没回收,这是个问题,另外如果多线程情况下多次delete也会造成崩溃。

C++11改进

我们可以利用C++11 提供的once_flag实现安全的创建

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
#ifndef DAY34_SINGLETON_SINGLETON_H
#define DAY34_SINGLETON_SINGLETON_H

#include <mutex>
#include <iostream>

class SingletonOnceFlag{
public:
static SingletonOnceFlag* getInstance(){
static std::once_flag flag;
std::call_once(flag, []{
_instance = new SingletonOnceFlag();
});
return _instance;
}

void PrintAddress() {
std::cout << _instance << std::endl;
}
~SingletonOnceFlag() {
std::cout << "this is singleton destruct" << std::endl;
}
private:
SingletonOnceFlag() = default;
SingletonOnceFlag(const SingletonOnceFlag&) = delete;
SingletonOnceFlag& operator=(const SingletonOnceFlag& st) = delete;
static SingletonOnceFlag* _instance;

};


#endif //DAY34_SINGLETON_SINGLETON_H

static成员要在cpp中初始化

1
SingletonOnceFlag *SingletonOnceFlag::_instance = nullptr;

测试

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
#include "Singleton.h"
#include <thread>
#include <mutex>
int main() {
system("chcp 65001 > nul");
std::mutex mtx;
std::thread t1([&](){
SingletonOnceFlag::getInstance();
std::lock_guard<std::mutex> lock(mtx);
SingletonOnceFlag::getInstance()->PrintAddress();
});

std::thread t2([&](){
SingletonOnceFlag::getInstance();
std::lock_guard<std::mutex> lock(mtx);
SingletonOnceFlag::getInstance()->PrintAddress();
});

t1.join();
t2.join();

return 0;
}

测试结果

1
2
0x19a74de7420
0x19a74de7420

智能指针方式(懒汉式)

可以利用智能指针自动回收内存的机制设计单例类

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
#ifndef DAY34_SINGLETON_SINGLETON_H
#define DAY34_SINGLETON_SINGLETON_H

#include <mutex>
#include <iostream>
#include <memory>

class SingletonOnceFlag{
public:
static std::shared_ptr<SingletonOnceFlag> getInstance(){
static std::once_flag flag;
std::call_once(flag, []{
_instance = std::shared_ptr<SingletonOnceFlag>(new SingletonOnceFlag());
});
return _instance;
}

void PrintAddress() {
std::cout << _instance << std::endl;
}
~SingletonOnceFlag() {
std::cout << "this is singleton destruct" << std::endl;
}
private:
SingletonOnceFlag() = default;
SingletonOnceFlag(const SingletonOnceFlag&) = delete;
SingletonOnceFlag& operator=(const SingletonOnceFlag& st) = delete;
static std::shared_ptr<SingletonOnceFlag> _instance;

};


#endif //DAY34_SINGLETON_SINGLETON_H

同样在SingletonOnceFlag.cpp中进行单例成员的初始化

1
2
3
#include "Singleton.h"

std::shared_ptr<SingletonOnceFlag> SingletonOnceFlag::_instance = nullptr;

再次测试

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
#include <iostream>
#include "Singleton.h"
#include <thread>
#include <mutex>
int main() {
system("chcp 65001 > nul");
std::mutex mtx;
std::thread t1([&](){
SingletonOnceFlag::getInstance();
std::lock_guard<std::mutex> lock(mtx);
SingletonOnceFlag::getInstance()->PrintAddress();
});

std::thread t2([&](){
SingletonOnceFlag::getInstance();
std::lock_guard<std::mutex> lock(mtx);
SingletonOnceFlag::getInstance()->PrintAddress();
});

t1.join();
t2.join();

return 0;
}

这次输出析构信息

1
2
3
0x1d620a47420
0x1d620a47420
this is singleton destruct

辅助类智能指针单例模式

智能指针在构造的时候可以指定删除器,所以可以传递一个辅助类或者辅助函数帮助智能指针回收内存时调用我们指定的析构函数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
//
// Created by secon on 2025/3/1.
//

#ifndef DAY34_SINGLETON_SINGLETON_H
#define DAY34_SINGLETON_SINGLETON_H

#include <mutex>
#include <iostream>
#include <memory>

class SingleAutoSafe;
class SafeDeletor
{
public:
void operator()(SingleAutoSafe *sf)
{
std::cout << "this is safe deleter operator()" << std::endl;
delete sf;
}
};

class SingleAutoSafe{
public:
static std::shared_ptr<SingleAutoSafe> getInstance(){
static std::once_flag flag;
std::call_once(flag, []{
_instance = std::shared_ptr<SingleAutoSafe>(new SingleAutoSafe(), SafeDeletor());
});
return _instance;
}

void PrintAddress() {
std::cout << _instance << std::endl;
}
//定义友元类,通过友元类调用该类析构函数
friend class SafeDeletor;
private:
SingleAutoSafe() = default;
SingleAutoSafe(const SingleAutoSafe&) = delete;
SingleAutoSafe& operator=(const SingleAutoSafe& st) = delete;
~SingleAutoSafe() {
std::cout << "this is singleton destruct" << std::endl;
}
static std::shared_ptr<SingleAutoSafe> _instance;

};


#endif //DAY34_SINGLETON_SINGLETON_H

在cpp文件中实现静态成员的定义

1
2
3
4

#include "Singleton.h"

std::shared_ptr<SingleAutoSafe> SingleAutoSafe::_instance = nullptr;

SafeDeletor要写在SingleAutoSafe上边,并且SafeDeletor要声明为SingleAutoSafe类的友元类,这样就可以访问SingleAutoSafe的析构函数了。

我们在构造single时制定了SafeDeletor(),single在回收时,会调用SingleAutoSafe的仿函数,从而完成内存的销毁。

并且SingleAutoSafe的析构函数为私有的无法被外界手动调用了。

测试

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
#include <iostream>
#include "Singleton.h"
#include <thread>
#include <mutex>
int main() {
system("chcp 65001 > nul");
std::mutex mtx;
std::thread t1([&](){
SingleAutoSafe::getInstance();
std::lock_guard<std::mutex> lock(mtx);
SingleAutoSafe::getInstance()->PrintAddress();
});

std::thread t2([&](){
SingleAutoSafe::getInstance();
std::lock_guard<std::mutex> lock(mtx);
SingleAutoSafe::getInstance()->PrintAddress();
});

t1.join();
t2.join();

return 0;
}

程序输出

1
2
3
0x1b379b07420
0x1b379b07420
this is safe deleter operator()

通用的单例模板类(CRTP)

我们可以通过声明单例的模板类,然后继承这个单例模板类的所有类就是单例类了。达到泛型编程提高效率的目的。

CRTP的概念

CRTP是一种将派生类作为模板参数传递给基类的技术,即一个类继承自一个以自身为模板参数的基类。这种模式常用于实现静态多态、接口的默认实现、编译时策略选择等。

比如

1
2
3
4
5
6
7
8
template <typename T>
class TempClass {
//...
};
//CRTP
class RealClass: public TempClass<RealClass>{
//...
};

单例基类实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
#include <memory>
#include <mutex>
#include <iostream>
using namespace std;
template <typename T>
class Singleton {
protected:
Singleton() = default;
Singleton(const Singleton<T>&) = delete;
Singleton& operator=(const Singleton<T>& st) = delete;

static std::shared_ptr<T> _instance;
public:
static std::shared_ptr<T> GetInstance() {
static std::once_flag s_flag;
std::call_once(s_flag, [&]() {
_instance = shared_ptr<T>(new T);
});

return _instance;
}
void PrintAddress() {
std::cout << _instance.get() << endl;
}
~Singleton() {
std::cout << "this is singleton destruct" << std::endl;
}
};

template <typename T>
std::shared_ptr<T> Singleton<T>::_instance = nullptr;

我们定义一个网络的单例类,继承上述模板类即可,并将构造和析构设置为私有,同时设置友元保证自己的析构和构造可以被友元类调用.

1
2
3
4
5
6
7
8
9
10
11
//通过继承方式实现网络模块单例
class SingleNet : public Singleton<SingleNet>
{
friend class Singleton<SingleNet>;
private:
SingleNet() = default;
public:
~SingleNet() {
std::cout << "SingleNet destruct " << std::endl;
}
};

测试案例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
#include <iostream>
#include "Singleton.h"
#include <thread>
#include <mutex>
int main() {
system("chcp 65001 > nul");
std::mutex mtx;
std::thread t1([&](){
SingleNet::GetInstance();
std::lock_guard<std::mutex> lock(mtx);
SingleNet::GetInstance()->PrintAddress();
});

std::thread t2([&](){
SingleNet::GetInstance();
std::lock_guard<std::mutex> lock(mtx);
SingleNet::GetInstance()->PrintAddress();
});

t1.join();
t2.join();

return 0;
}

程序输出

1
2
3
4
0x212248b7420
0x212248b7420
SingleNet destruct
this is singleton destruct

源码和视频

视频地址

源码地址

引用折叠和原样转发

Posted on 2025-02-22 | In 零基础C++

1. 左值与右值

1.1 定义与分类

左值(lvalue)和右值(rvalue)是C++中用于描述表达式值类别的重要概念。

  • 左值(lvalue):
    • 表示具有持久存储的对象。
    • 可以出现在赋值语句的左侧。
    • 可以被取地址(即,可以使用&运算符)。
    • 示例:变量名、引用等。
  • 右值(rvalue):
    • 表示临时对象或没有持久存储的值。
    • 通常出现在赋值语句的右侧。
    • 不能被取地址。
    • 示例:字面量、临时对象、表达式结果等。

C++11进一步细化了右值的分类:

  • 纯右值(prvalues):表示临时对象或字面量,如42、3.14。
  • 将亡值(xvalues,expiring values):表示即将被移动的对象,如std::move的结果。
Read more »

DeepSeek-R1本地部署知识库

Posted on 2025-02-08 | In AI

视频教程

考虑大家看文档会比较吃力,可以参考我的视频教程

https://www.bilibili.com/video/BV1GkNyeDEpK/?pop_share=1&vd_source=8be9e83424c2ed2c9b2a3ed1d01385e9

DeepSeek简介

深度求索人工智能基础技术研究有限公司(简称“深度求索”或“DeepSeek”),成立于2023年,是一家专注于实现AGI的中国公司。

产品DeepSeek为一款AI工具,可以解析文本,生成代码,推理解析等。

模型 & 价格

下表所列模型价格以“百万 tokens”为单位。Token 是模型用来表示自然语言文本的的最小单位,可以是一个词、一个数字或一个标点符号等。我们将根据模型输入和输出的总 token 数进行计量计费。

模型(1) 上下文长度 最大思维链长度(2) 最大输出长度(3) 百万tokens 输入价格 (缓存命中)(4) 百万tokens 输入价格 (缓存未命中) 百万tokens 输出价格 输出价格
deepseek-chat 64K - 8K 0.5元 2元 8元
deepseek-reasoner 64K 32K 8K 1元 4元 16元(5)
  1. deepseek-chat 模型已经升级为 DeepSeek-V3;deepseek-reasoner 模型为新模型 DeepSeek-R1。
  2. 思维链为deepseek-reasoner模型在给出正式回答之前的思考过程,其原理详见推理模型。
  3. 如未指定 max_tokens,默认最大输出长度为 4K。请调整 max_tokens 以支持更长的输出。
  4. 关于上下文缓存的细节,请参考DeepSeek 硬盘缓存。
  5. deepseek-reasoner的输出 token 数包含了思维链和最终答案的所有 token,其计价相同。

性能对齐 OpenAI-o1 正式版

DeepSeek-R1 在后训练阶段大规模使用了强化学习技术,在仅有极少标注数据的情况下,极大提升了模型推理能力。在数学、代码、自然语言推理等任务上,性能比肩 OpenAI o1 正式版。

https://cdn.llfc.club/abbe5dcf3b85c33ded2d624fe84a2ec.png

蒸馏小模型超越 OpenAI o1-mini

在开源 DeepSeek-R1-Zero 和 DeepSeek-R1 两个 660B 模型的同时,通过 DeepSeek-R1 的输出,蒸馏了 6 个小模型开源给社区,其中 32B 和 70B 模型在多项能力上实现了对标 OpenAI o1-mini 的效果。

img

API 及定价

DeepSeek-R1 API 服务定价为每百万输入 tokens 1 元(缓存命中)/ 4 元(缓存未命中),每百万输出 tokens 16 元。

img

img

详细的 API 调用指南请参考官方文档: https://api-docs.deepseek.com/zh-cn/guides/reasoning_model

本地部署

Ollama 官方版:【点击前往】

image-20250209105503303

下载windows版本并安装

image-20250209111710518

我在windows做测试,然后部署

安装命令

1.5B Qwen DeepSeek R1(需4G显存)

1
ollama run deepseek-r1:1.5b

7B Qwen DeepSeek R1(4~12G显存)

1
ollama run deepseek-r1:7b

8B Llama DeepSeek R1(4~12G显存)

1
ollama run deepseek-r1:8b

14B Qwen DeepSeek R1(12~24G显存)

1
ollama run deepseek-r1:14b

32B Qwen DeepSeek R1(24G显存)

1
ollama run deepseek-r1:32b

70B Llama DeepSeek R1(32G显存以上)

1
ollama run deepseek-r1:70b

我们选择1.5b模型安装

image-20250209112145406

安装成功image-20250209114003740

AnythingLLM 下载

image-20250209114426088

1、Github 开源版 【点击下载】

2、官方版:【点击下载】

我们选择开源版下载,拉到最下面有markdown文档

image-20250209114631891

点击DownloadNow, 选择x64版本下载

image-20250209115524398

下载完成,安装

image-20250209115436788

安装完成

image-20250209121534198

点击完成自动运行

image-20250209121713655

点击Get started,然后一路跳过

image-20250209121816605

记得填写邮箱和使用目的

然后再填写工作区

image-20250209122112515

点击右侧箭头完成,进入使用界面,如果此时使用会出现模型未设置

image-20250209122226640

点击设置按钮,选择聊天选项,接下来选择模型

image-20250209122457152

设置模型为ollama

image-20250209122532873

设置后

image-20250209122828156

记得更新

image-20250209122859291

注意:本地部署也是支持开启联网搜索模式的

点击设置按钮

image-20250209123038941

只需在AnythingLLM的设置界面中,找到“代理技能”选项。 启用Web Search:在代理技能列表中找到 Web Search ,点击开启。 选择搜索引擎即可!

如下图所示:

image-20250209123125101

测试

测试代码生成能力

image-20250209124009974

测试检索能力

image-20250209124053537

测试推理能力

image-20250209124127054

模板详解

Posted on 2025-01-29 | In 零基础C++

模板基础

C++ 模板(Templates)是现代 C++ 中强大而灵活的特性,支持泛型编程,使得代码更具复用性和类型安全性。模板不仅包括基本的函数模板和类模板,还涵盖了模板特化(全特化与偏特化)、模板参数种类、变参模板(Variadic Templates)、模板元编程(Template Metaprogramming)、SFINAE(Substitution Failure Is Not An Error)等高级内容。

函数模板

函数模板允许编写通用的函数,通过类型参数化,使其能够处理不同的数据类型。它通过模板参数定义与类型无关的函数。

语法:

1
2
3
4
template <typename T>
T functionName(T param) {
// 函数体
}

示例:最大值函数

1
2
3
4
5
6
7
8
9
10
11
12
13
#include <iostream>

template <typename T>
T maxValue(T a, T b) {
return (a > b) ? a : b;
}

int main() {
std::cout << maxValue(3, 7) << std::endl; // int 类型
std::cout << maxValue(3.14, 2.72) << std::endl; // double 类型
std::cout << maxValue('a', 'z') << std::endl; // char 类型
return 0;
}

输出:

1
2
3
7
3.14
z

要点:

  • 模板参数列表以 template <typename T> 或 template <class T> 开头,两者等价。
  • 类型推导:编译器根据函数参数自动推导模板参数类型。

类模板

类模板允许定义通用的类,通过类型参数化,实现不同类型的对象。

语法:

1
2
3
4
5
6
template <typename T>
class ClassName {
public:
T memberVariable;
// 构造函数、成员函数等
};

示例:简单的 Pair 类

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
#include <iostream>
#include <string>

template <typename T, typename U>
class Pair {
public:
T first;
U second;

Pair(T a, U b) : first(a), second(b) {}

void print() const {
std::cout << "Pair: " << first << ", " << second << std::endl;
}
};

int main() {
Pair<int, double> p1(1, 2.5);
p1.print(); // 输出:Pair: 1, 2.5

Pair<std::string, std::string> p2("Hello", "World");
p2.print(); // 输出:Pair: Hello, World

Pair<std::string, int> p3("Age", 30);
p3.print(); // 输出:Pair: Age, 30

return 0;
}

输出:

1
2
3
Pair: 1, 2.5
Pair: Hello, World
Pair: Age, 30

要点:

  • 类模板可以有多个类型参数。
  • 模板参数可以被用于成员变量和成员函数中。
  • 类模板实例化时指定具体类型,如 Pair<int, double>。

模板参数

模板参数决定了模板的泛用性与灵活性。C++ 模板参数种类主要包括类型参数、非类型参数和模板模板参数。

类型参数(Type Parameters)

类型参数用于表示任意类型,在模板实例化时被具体的类型替代。

示例:

1
2
3
4
5
template <typename T>
class MyClass {
public:
T data;
};

非类型参数(Non-Type Parameters)

非类型参数允许模板接受非类型的值,如整数、指针或引用。C++17 支持更多非类型参数类型,如 auto。

语法:

1
2
3
4
5
template <typename T, int N>
class FixedArray {
public:
T data[N];
};

示例:固定大小的数组类

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
#include <iostream>

template <typename T, std::size_t N>
class FixedArray {
public:
T data[N];

T& operator[](std::size_t index) {
return data[index];
}

void print() const {
for(std::size_t i = 0; i < N; ++i)
std::cout << data[i] << " ";
std::cout << std::endl;
}
};

int main() {
FixedArray<int, 5> arr;
for(int i = 0; i < 5; ++i)
arr[i] = i * 10;
arr.print(); // 输出:0 10 20 30 40
return 0;
}

输出:

1
0 10 20 30 40 

注意事项:

  • 非类型参数必须是编译期常量。
  • 允许的类型包括整型、枚举、指针、引用等,但不包括浮点数和类类型。

模板模板参数(Template Template Parameters)

模板模板参数允许模板接受另一个模板作为参数。这对于抽象容器和策略模式等场景非常有用。

语法:

1
2
template <template <typename, typename> class Container>
class MyClass { /* ... */ };

示例:容器适配器

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
#include <iostream>
#include <vector>
#include <list>

template <template <typename, typename> class Container, typename T>
class ContainerPrinter {
public:
void print(const Container<T, std::allocator<T>>& container) {
for(const auto& elem : container)
std::cout << elem << " ";
std::cout << std::endl;
}
};

int main() {
std::vector<int> vec = {1, 2, 3, 4, 5};
std::list<int> lst = {10, 20, 30};

ContainerPrinter<std::vector, int> vecPrinter;
vecPrinter.print(vec); // 输出:1 2 3 4 5

ContainerPrinter<std::list, int> listPrinter;
listPrinter.print(lst); // 输出:10 20 30

return 0;
}

输出:

1
2
1 2 3 4 5 
10 20 30

要点:

  • 模板模板参数需要完全匹配被接受模板的参数列表。
  • 可通过默认模板参数增强灵活性。

模板特化(Template Specialization)

模板特化允许开发者为特定类型或类型组合提供专门的实现。当通用模板无法满足特定需求时,特化模板可以调整行为以处理特定的情况。C++ 支持全特化(Full Specialization)**和**偏特化(Partial Specialization)**,但需要注意的是,函数模板不支持偏特化**,只能进行全特化。

全特化(Full Specialization)

全特化是针对模板参数的完全特定类型组合。它提供了模板的一个特定版本,当模板参数完全匹配特化类型时,编译器将优先使用该特化版本。

语法

1
2
3
4
5
6
7
8
9
10
11
// 通用模板
template <typename T>
class MyClass {
// 通用实现
};

// 全特化
template <>
class MyClass<SpecificType> {
// 针对 SpecificType 的实现
};

示例:类模板全特化

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
#include <iostream>
#include <string>

// 通用类模板
template <typename T>
class Printer {
public:
void print(const T& value) {
std::cout << "General Printer: " << value << std::endl;
}
};

// 类模板全特化
template <>
class Printer<std::string> {
public:
void print(const std::string& value) {
std::cout << "String Printer: " << value << std::endl;
}
};

int main() {
Printer<int> intPrinter;
intPrinter.print(100); // 输出:General Printer: 100

Printer<std::string> stringPrinter;
stringPrinter.print("Hello, World!"); // 输出:String Printer: Hello, World!

return 0;
}

输出:

1
2
General Printer: 100
String Printer: Hello, World!

解析

  • 通用模板适用于所有类型,在print函数中以通用方式输出值。
  • 全特化模板针对std::string类型进行了专门化,实现了不同的print函数。
  • 当实例化Printer<std::string>时,编译器选择全特化版本而非通用模板。

偏特化(Partial Specialization)

偏特化允许模板对部分参数进行特定类型的处理,同时保持其他参数的通用性。对于类模板而言,可以针对模板参数的某些特性进行偏特化;对于函数模板,则仅支持全特化,不支持偏特化。

语法

1
2
3
4
5
6
7
8
9
10
11
// 通用模板
template <typename T, typename U>
class MyClass {
// 通用实现
};

// 偏特化:当 U 是指针类型时
template <typename T, typename U>
class MyClass<T, U*> {
// 针对 U* 的实现
};

示例:类模板偏特化

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
#include <iostream>
#include <string>

// 通用 Pair 类模板
template <typename T, typename U>
class Pair {
public:
T first;
U second;

Pair(T a, U b) : first(a), second(b) {}

void print() const {
std::cout << "Pair: " << first << ", " << second << std::endl;
}
};

// 类模板偏特化:当第二个类型是指针时
template <typename T, typename U>
class Pair<T, U*> {
public:
T first;
U* second;

Pair(T a, U* b) : first(a), second(b) {}

void print() const {
std::cout << "Pair with pointer: " << first << ", " << *second << std::endl;
}
};

int main() {
Pair<int, double> p1(1, 2.5);
p1.print(); // 输出:Pair: 1, 2.5

double value = 3.14;
Pair<std::string, double*> p2("Pi", &value);
p2.print(); // 输出:Pair with pointer: Pi, 3.14

return 0;
}

输出:

1
2
Pair: 1, 2.5
Pair with pointer: Pi, 3.14

解析

  • 通用模板处理非指针类型对。
  • 偏特化模板处理第二个类型为指针的情况,打印指针指向的值。
  • 使用偏特化提升了模板的灵活性,使其能够根据部分参数类型进行不同处理。

函数模板的特化

与类模板不同,函数模板不支持偏特化,只能进行全特化。当对函数模板进行全特化时,需要显式指定类型。

示例:函数模板全特化

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
#include <iostream>
#include <string>

// 通用函数模板
template <typename T>
void printValue(const T& value) {
std::cout << "General print: " << value << std::endl;
}

// 函数模板全特化
template <>
void printValue<std::string>(const std::string& value) {
std::cout << "Specialized print for std::string: " << value << std::endl;
}

int main() {
printValue(42); // 调用通用模板,输出:General print: 42
printValue(3.14); // 调用通用模板,输出:General print: 3.14
printValue(std::string("Hello")); // 调用全特化模板,输出:Specialized print for std::string: Hello
return 0;
}

输出:

1
2
3
General print: 42
General print: 3.14
Specialized print for std::string: Hello

解析

  • 通用函数模板适用于所有类型,提供通用的printValue实现。
  • 全特化函数模板专门处理std::string类型,提供不同的输出格式。
  • 调用printValue时,编译器根据实参类型选择适当的模板版本。

注意事项

  • 优先级:全特化版本的优先级高于通用模板,因此当特化条件满足时,总是选择特化版本。
  • 显式指定类型:函数模板特化需要在调用时显式指定类型,或者确保类型推导可以正确匹配特化版本。
  • 不支持偏特化:无法通过偏特化对函数模板进行部分特化,需要通过其他方法(如重载)实现类似功能。

总结

  • 全特化适用于为具体类型或类型组合提供专门实现,适用于类模板和函数模板。
  • 偏特化仅适用于类模板,允许针对部分参数进行特定处理,同时保持其他参数的通用性。
  • 函数模板仅支持全特化,不支持偏特化;类模板支持全特化和偏特化。
  • 特化模板提升了模板的灵活性和适应性,使其能够根据不同类型需求调整行为。

变参模板(Variadic Templates)

变参模板允许模板接受可变数量的参数,提供极高的灵活性,是实现诸如 std::tuple、std::variant 等模板库组件的基础。

定义与语法

变参模板使用 参数包(Parameter Pack),通过 ... 语法来表示。

语法:

1
2
3
4
5
template <typename... Args>
class MyClass { /* ... */ };

template <typename T, typename... Args>
void myFunction(T first, Args... args) { /* ... */ }

递归与展开(Recursion and Expansion)

变参模板通常与递归相结合,通过递归地处理参数包,或者使用 折叠表达式(Fold Expressions) 来展开发参数包。

递归示例:打印所有参数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
#include <iostream>

// 基础情况:无参数
void printAll() {
std::cout << std::endl;
}

// 递归情况:至少一个参数
template <typename T, typename... Args>
void printAll(const T& first, const Args&... args) {
std::cout << first << " ";
printAll(args...); // 递归调用
}

int main() {
printAll(1, 2.5, "Hello", 'A'); // 输出:1 2.5 Hello A
return 0;
}

输出:

1
1 2.5 Hello A 

折叠表达式版本

1
2
3
4
5
6
7
8
9
10
11
12
13
14
#include <iostream>

// 使用折叠表达式的printAll
template <typename... Args>
void printAll(const Args&... args) {
// 使用左折叠展开参数包,并在每个参数之后输出一个空格
((std::cout << args << " "), ...);
std::cout << std::endl;
}

int main() {
printAll(1, 2.5, "Hello", 'A'); // 输出:1 2.5 Hello A
return 0;
}

折叠表达式示例:计算总和

C++17 引入了折叠表达式,简化了参数包的处理。

1
2
3
4
5
6
7
8
9
10
11
12
#include <iostream>

template <typename... Args>
auto sum(Args... args) -> decltype((args + ...)) {
return (args + ...); // 左折叠
}

int main() {
std::cout << sum(1, 2, 3, 4) << std::endl; // 输出:10
std::cout << sum(1.5, 2.5, 3.0) << std::endl; // 输出:7
return 0;
}

输出:

1
2
10
7

应用示例

示例:日志记录器

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
#include <iostream>
#include <string>

// 基础情况:无参数
void log(const std::string& msg) {
std::cout << msg << std::endl;
}

// 递归情况:至少一个参数
template <typename T, typename... Args>
void log(const std::string& msg, const T& first, const Args&... args) {
std::cout << msg << ": " << first << " ";
log("", args...); // 递归调用,省略消息前缀
}

int main() {
log("Error", 404, "Not Found");
// 输出:Error: 404 Not Found

log("Sum", 10, 20, 30);
// 输出:Sum: 10 20 30
return 0;
}

输出:

1
2
Error: 404 Not Found 
Sum: 10 20 30

要点:

  • 变参模板极大地提升了模板的灵活性。
  • 使用递归或折叠表达式处理参数包。
  • 常用于实现通用函数、容器类和元编程工具。

模板折叠(Fold Expressions)

1. 折叠表达式的概念与背景

在C++中,可变参数模板允许函数或类模板接受任意数量的模板参数。这在编写灵活且通用的代码时非常有用。然而,处理参数包中的每个参数往往需要递归模板技巧,这样的代码通常复杂且难以维护。

折叠表达式的引入显著简化了这一过程。它们允许开发者直接对参数包应用操作符,而无需手动展开或递归处理参数。这不仅使代码更加简洁,还提高了可读性和可维护性。

折叠表达式可分为:

  • 一元折叠表达式(Unary Fold):对参数包中的每个参数应用一个一元操作符。
  • 二元折叠表达式(Binary Fold):对参数包中的每个参数应用一个二元操作符。

此外,二元折叠表达式可进一步细分为**左折叠(Left Fold)**和**右折叠(Right Fold)**,取决于操作符的结合方向。

2. 一元折叠表达式(Unary Fold)

一元折叠表达式用于在参数包的每个参数前或后应用一元操作符。语法形式如下:

前置一元折叠(Unary Prefix Fold)

1
(op ... pack)

后置一元折叠(Unary Postfix Fold)

1
(pack ... op)

其中,op 是一元操作符,如!(逻辑非)、~(按位取反)等。

示例1:逻辑非操作

1
2
3
4
5
6
7
8
#include <iostream>

//对每个参数非操作,然后再将这些操作&&
//(!args && ...) 相当于 !a && !b && ...
template<typename... Args>
bool allNot(const Args&... args){
return (!args && ...);
}

3. 二元折叠表达式(Binary Fold)

二元折叠表达式用于在参数包的每个参数之间应用一个二元操作符。它们可以分为**二元左折叠(Binary Left Fold)**和**二元右折叠(Binary Right Fold)**,取决于操作符的结合方向。

二元折叠表达式语法

  • 二元左折叠(Left Fold):

    1
    (init op ... op pack)

    或者简化为:

    1
    (pack1 op ... op packN)
  • 二元右折叠(Right Fold):

    1
    (pack1 op ... op init op ...)

    或者简化为:

    1
    (pack1 op ... op packN)

其中,op 是二元操作符,如+、*、&&、||、<< 等。

左折叠与右折叠的区别

  • 二元左折叠(Binary Left Fold):操作符从左至右结合,等价于 (((a op b) op c) op d)。
  • 二元右折叠(Binary Right Fold):操作符从右至左结合,等价于 (a op (b op (c op d)))。

示例1:求和(Binary Left Fold)

1
2
3
4
5
6
7
8
9
10
11
12
#include <iostream>

// 二元左折叠:((arg1 + arg2) + arg3) + ... + argN
template<typename... Args>
auto sumLeftFold(const Args&... args) {
return (args + ...); // 左折叠
}

int main() {
std::cout << sumLeftFold(1, 2, 3, 4) << std::endl; // 输出:10
return 0;
}

解释:

  • (args + ...) 是一个二元左折叠表达式。
  • 它将+操作符逐个应用于参数,按照左折叠顺序。
  • 即,((1 + 2) + 3) + 4 = 10。

示例2:乘积(Binary Right Fold)

1
2
3
4
5
6
7
8
9
10
11
12
#include <iostream>

// 二元右折叠:arg1 * (arg2 * (arg3 * ... * argN))
template<typename... Args>
auto productRightFold(const Args&... args) {
return (... * args); // 右折叠
}

int main() {
std::cout << productRightFold(2, 3, 4) << std::endl; // 输出:24
return 0;
}

解释:

  • (... \* args) 是一个二元右折叠表达式。
  • 它将*操作符逐个应用于参数,按照右折叠顺序。
  • 即,2 * (3 * 4) = 2 * 12 = 24。

示例3:逻辑与(Binary Left Fold)

1
2
3
4
5
6
7
8
9
10
11
12
13
#include <iostream>

template<typename... Args>
bool allTrue(const Args&... args) {
return (args && ...); // 左折叠
}

int main() {
std::cout << std::boolalpha;
std::cout << allTrue(true, true, false) << std::endl; // 输出:false
std::cout << allTrue(true, true, true) << std::endl; // 输出:true
return 0;
}

解释:

  • (args && ...) 是一个二元左折叠表达式。
  • 用于检查所有参数是否为true。
  • 类似于链式的逻辑与运算。

4. 左折叠与右折叠(Left and Right Folds)

了解左折叠和右折叠的区别,对于正确选择折叠表达式的形式至关重要。

二元左折叠(Binary Left Fold)

  • 语法:

    1
    (args op ...)
  • 展开方式:

    1
    ((arg1 op arg2) op arg3) op ... op argN
  • 适用场景:

    • 当操作符是结合性的且从左侧开始累积操作时(如+、*)。
    • 需要严格的顺序执行时,确保从左到右依次处理参数。
  • 示例:

    1
    (args + ...) // 左折叠求和

二元右折叠(Binary Right Fold)

  • 语法:

    1
    (... op args)
  • 展开方式:

    1
    arg1 op (arg2 op (arg3 op ... op argN))
  • 适用场景:

    • 当操作符是右结合的,或当需要从右侧开始累积操作时。
    • 某些特定的逻辑和数据结构可能需要右侧先处理。
  • 示例:

    1
    (... + args) // 右折叠求和

嵌套折叠表达式

在某些复杂场景下,可能需要嵌套使用左折叠和右折叠,以达到特定的操作顺序。例如,结合多个不同的操作符。

1
2
3
4
5
6
7
8
9
10
11
12
#include <iostream>

template<typename... Args>
auto complexFold(const Args&... args) {
// 先左折叠求和,然后右折叠求乘积
return (args + ...) * (... + args);
}

int main() {
std::cout << complexFold(1, 2, 3) << std::endl; // (1+2+3) * (1+2+3) = 6 * 6 = 36
return 0;
}

解释:

  • 在此示例中,我们首先对参数进行左折叠求和,然后对参数进行右折叠求和,最后将两者相乘。
  • 这种嵌套用途展示了折叠表达式的灵活性。

5. op 在折叠表达式中的作用

在折叠表达式中,op 代表二元操作符,用于定义如何将参数包中的各个参数相互结合。op 可以是任何合法的二元操作符,包括但不限于:

  • 算术操作符:+、-、*、/、% 等。
  • 逻辑操作符:&&、|| 等。
  • 按位操作符:&、|、^、<<、>> 等。
  • 比较操作符:==、!=、<、>、<=、>= 等。
  • 自定义操作符:如果定义了自定义类型并重载了特定的操作符,也可以使用这些操作符。

op 的选择直接影响折叠表达式的行为和结果。选择适当的操作符是实现特定功能的关键。

示例1:使用加法操作符

1
2
3
4
5
6
7
8
9
10
11
#include <iostream>

template<typename... Args>
auto addAll(const Args&... args) {
return (args + ...); // 使用 '+' 进行左折叠
}

int main() {
std::cout << addAll(1, 2, 3, 4) << std::endl; // 输出:10
return 0;
}

示例2:使用逻辑与操作符

1
2
3
4
5
6
7
8
9
10
11
12
13
#include <iostream>

template<typename... Args>
bool allTrue(const Args&... args) {
return (args && ...); // 使用 '&&' 进行左折叠
}

int main() {
std::cout << std::boolalpha;
std::cout << allTrue(true, true, false) << std::endl; // 输出:false
std::cout << allTrue(true, true, true) << std::endl; // 输出:true
return 0;
}

示例3:使用左移操作符(流插入)

1
2
3
4
5
6
7
8
9
10
11
#include <iostream>

template<typename... Args>
void printAll(const Args&... args) {
(std::cout << ... << args) << std::endl; // 使用 '<<' 进行左折叠
}

int main() {
printAll("Hello, ", "world", "!", 123); // 输出:Hello, world!123
return 0;
}

解释:

  • 在上述示例中,op 分别为 +、&&、<<。

  • 每个操作符定义了如何将参数包中的元素相互结合。

示例4:使用自定义操作符

假设有一个自定义类型Point,并重载了+操作符以支持点的相加。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
#include <iostream>

struct Point {
int x, y;

// 重载 '+' 操作符
Point operator+(const Point& other) const {
return Point{ x + other.x, y + other.y };
}
};

// 二元左折叠:((p1 + p2) + p3) + ... + pN
template<typename... Args>
Point sumPoints(const Args&... args) {
return (args + ...); // 使用 '+' 进行左折叠
}

int main() {
Point p1{1, 2}, p2{3, 4}, p3{5, 6};
Point result = sumPoints(p1, p2, p3);
std::cout << "Sum of Points: (" << result.x << ", " << result.y << ")\n"; // 输出:(9, 12)
return 0;
}

解释:

  • 通过重载+操作符,sumPoints函数能够将多个Point对象相加,得到累积的结果。

6. 示例代码与应用

为了全面理解折叠表达式的应用,以下提供多个具体示例,涵盖不同类型的折叠表达式。

示例1:字符串拼接

1
2
3
4
5
6
7
8
9
10
11
12
13
#include <iostream>
#include <string>

template<typename... Args>
std::string concatenate(const Args&... args) {
return (std::string{} + ... + args); // 左折叠
}

int main() {
std::string result = concatenate("Hello, ", "world", "!", " Have a nice day.");
std::cout << result << std::endl; // 输出:Hello, world! Have a nice day.
return 0;
}

示例2:计算逻辑与

1
2
3
4
5
6
7
8
9
10
11
12
13
#include <iostream>

template<typename... Args>
bool areAllTrue(const Args&... args) {
return (args && ...); // 左折叠
}

int main() {
std::cout << std::boolalpha;
std::cout << areAllTrue(true, true, true) << std::endl; // 输出:true
std::cout << areAllTrue(true, false, true) << std::endl; // 输出:false
return 0;
}

示例3:计算最大值

1
2
3
4
5
6
7
8
9
10
11
12
#include <iostream>
#include <algorithm>

template<typename T, typename... Args>
T maxAll(T first, Args... args) {
return (std::max)(first, ... , args); // 左折叠
}

int main() {
std::cout << maxAll(1, 5, 3, 9, 2) << std::endl; // 输出:9
return 0;
}

注意:上述示例中的(std::max)(first, ... , args)是一个非标准用法,需要根据具体情况调整。通常,std::max不支持直接的折叠表达式,因此此例更适合作为概念性说明。在实际应用中,可以使用std::initializer_list或其他方法实现多参数的最大值计算。

示例4:筛选逻辑

假设需要检查多个条件是否满足,且每个条件之间使用逻辑或操作:

1
2
3
4
5
6
7
8
9
10
11
12
13
#include <iostream>

template<typename... Args>
bool anyTrue(const Args&... args) {
return (args || ...); // 左折叠
}

int main() {
std::cout << std::boolalpha;
std::cout << anyTrue(false, false, true) << std::endl; // 输出:true
std::cout << anyTrue(false, false, false) << std::endl; // 输出:false
return 0;
}

7. 注意事项与最佳实践

1. 操作符的选择

选择合适的操作符(op)对于实现正确的折叠行为至关重要。确保所选的操作符符合所需的逻辑和计算需求。

2. 操作符的结合性

不同的操作符具有不同的结合性(左结合、右结合)。了解操作符的结合性有助于选择正确的折叠方向(左折叠或右折叠)。

3. 参数包的初始化

在二元折叠表达式中,有时需要一个初始值(init)。这主要用于确保折叠的正确性,尤其在参数包可能为空的情况下。

示例:

1
2
3
4
5
6
7
8
9
10
11
12
#include <iostream>
#include <numeric>

template<typename... Args>
auto sumWithInit(int init, Args... args) {
return (init + ... + args); // 左折叠
}

int main() {
std::cout << sumWithInit(10, 1, 2, 3) << std::endl; // 输出:16 (10 + 1 + 2 + 3)
return 0;
}

4. 参数包为空的情况

如果参数包为空,折叠表达式的结果取决于折叠的类型和初始值。合理设置初始值可以避免潜在的错误。

示例:

1
2
3
4
5
6
7
8
9
10
11
12
13
#include <iostream>

// 求和函数,如果参数包为空返回0
template<typename... Args>
auto sum(Args... args) {
return (0 + ... + args); // 左折叠,初始值为0
}

int main() {
std::cout << sum(1, 2, 3) << std::endl; // 输出:6
std::cout << sum() << std::endl; // 输出:0
return 0;
}

5. 与递归模板的比较

折叠表达式在处理可变参数模板时,比传统的递归模板方法更简洁、易读且易于维护。然而,理解折叠表达式的基本原理和语法对于充分利用其优势至关重要。

6. 编译器支持

确保所使用的编译器支持C++17或更高标准,因为折叠表达式是在C++17中引入的。常见的支持C++17的编译器包括:

  • GCC:从版本7开始支持C++17,其中完整支持在后续版本中得到增强。
  • Clang:从版本5开始支持C++17。
  • MSVC(Visual Studio):从Visual Studio 2017版本15.7开始提供较全面的C++17支持。

7. 性能考虑

折叠表达式本身并不引入额外的性能开销。它们是在编译时展开的,生成的代码与手动展开参数包时的代码几乎相同。然而,编写高效的折叠表达式仍然需要理解所应用操作符的性能特性。


SFINAE(Substitution Failure Is Not An Error)

一、什么是SFINAE?

SFINAE 是 “Substitution Failure Is Not An Error”(替换失败不是错误)的缩写,是C++模板编程中的一个重要概念。它允许编译器在模板实例化过程中,如果在替换模板参数时失败(即不满足某些条件),不会将其视为编译错误,而是继续寻找其他可能的模板或重载。这一机制为条件编译、类型特性检测、函数重载等提供了强大的支持。

二、SFINAE的工作原理

在模板实例化过程中,编译器会尝试将模板参数替换为具体类型。如果在替换过程中出现不合法的表达式或类型,编译器不会报错,而是将该模板视为不可行的,继续尝试其他模板或重载。这一特性允许开发者根据类型特性选择不同的模板实现。

三、SFINAE的应用场景

  1. 函数重载选择:根据参数类型的不同选择不同的函数实现。
  2. 类型特性检测:检测类型是否具有某些成员或特性,从而决定是否启用某些功能。
  3. 条件编译:根据模板参数的特性决定是否编译某些代码段。

四、SFINAE的基本用法

SFINAE通常与std::enable_if、模板特化、以及类型萃取等技术结合使用。以下通过几个例子来说明SFINAE的应用。

示例一:通过std::enable_if实现函数重载

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
#include <type_traits>
#include <iostream>

// 适用于整数类型
template <typename T>
typename std::enable_if<std::is_integral<T>::value, void>::type
print_type(T value) {
std::cout << "Integral type: " << value << std::endl;
}

// 适用于浮点类型
template <typename T>
typename std::enable_if<std::is_floating_point<T>::value, void>::type
print_type(T value) {
std::cout << "Floating point type: " << value << std::endl;
}

int main() {
print_type(10); // 输出: Integral type: 10
print_type(3.14); // 输出: Floating point type: 3.14
// print_type("Hello"); // 编译错误,没有匹配的函数
return 0;
}

解释:

  • std::enable_if 根据条件 std::is_integral<T>::value 或 std::is_floating_point<T>::value 决定是否启用对应的函数模板。
  • 当条件不满足时,该模板实例化失败,但由于SFINAE规则,编译器不会报错,而是忽略该模板,从而实现函数重载选择。

示例二:检测类型是否具有特定成员

假设我们需要实现一个函数,仅当类型 T 具有成员函数 foo 时才启用该函数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
#include <type_traits>
#include <iostream>

// 辅助类型,检测是否存在成员函数 foo
template <typename T>
class has_foo {
private:
typedef char yes[1];
typedef char no[2];

template <typename U, void (U::*)()>
struct SFINAE {};

template <typename U>
static yes& test(SFINAE<U, &U::foo>*);

template <typename U>
static no& test(...);

public:
static constexpr bool value = sizeof(test<T>(0)) == sizeof(yes);
};

// 函数仅在 T 有 foo() 成员时启用
template <typename T>
typename std::enable_if<has_foo<T>::value, void>::type
call_foo(T& obj) {
obj.foo();
std::cout << "foo() called." << std::endl;
}

class WithFoo {
public:
void foo() { std::cout << "WithFoo::foo()" << std::endl; }
};

class WithoutFoo {};

int main() {
WithFoo wf;
call_foo(wf); // 输出: WithFoo::foo() \n foo() called.

// WithoutFoo wf2;
// call_foo(wf2); // 编译错误,没有匹配的函数
return 0;
}

解释:

  • has_foo 是一个类型萃取类,用于检测类型 T 是否具有成员函数 foo。
  • call_foo 函数模板仅在 T 具有 foo 成员时启用。
  • 对于不具有 foo 成员的类型,编译器会忽略 call_foo,从而避免编译错误。

示例三:通过模板特化实现不同的行为

以下是完整的、正确实现 TypePrinter 的代码示例:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
#include <type_traits>
#include <iostream>

// 1. 定义一个 Trait 用于检测 T 是否有非 void 的 `value_type`
template <typename T, typename = void>
struct has_non_void_value_type : std::false_type {};

// 仅当 T 有 `value_type` 且 `value_type` 不是 void 时,特化为 std::true_type
template <typename T>
struct has_non_void_value_type<T, std::enable_if_t<!std::is_void_v<typename T::value_type>>> : std::true_type {};

// 2. 定义 TypePrinter 主模板,使用一个布尔参数控制特化
template <typename T, bool HasValueType = has_non_void_value_type<T>::value>
struct TypePrinter;

// 3. 特化:当 HasValueType 为 true 时,表示 T 有非 void 的 `value_type`
template <typename T>
struct TypePrinter<T, true> {
static void print(){
std::cout << "T has a member type 'value_type'." << std::endl;
}
};

// 特化:当 HasValueType 为 false 时,表示 T 没有 `value_type` 或 `value_type` 是 void
template <typename T>
struct TypePrinter<T, false> {
static void print(){
std::cout << "hello world! T does not have a member type 'value_type'." << std::endl;
}
};

// 测试结构体
struct WithValueType{
using value_type = int;
};

struct WithoutValueType{};

struct WithVoidValueType{
using value_type = void;
};

int main() {
TypePrinter<WithValueType>::print(); // 输出: T has a member type 'value_type'.
TypePrinter<WithoutValueType>::print(); // 输出: hello world! T does not have a member type 'value_type'.
TypePrinter<WithVoidValueType>::print(); // 输出: hello world! T does not have a member type 'value_type'.
return 0;
}

代码解释

  1. Trait has_non_void_value_type:
    • 主模板:默认情况下,has_non_void_value_type<T> 继承自 std::false_type,表示 T 没有 value_type 或 value_type 是 void。
    • 特化模板:仅当 T 有 value_type 且 value_type 不是 void 时,has_non_void_value_type<T> 继承自 std::true_type。
  2. TypePrinter 模板:
    • 主模板:接受一个类型 T 和一个布尔模板参数 HasValueType,默认为 has_non_void_value_type<T>::value。
    • **特化版本 TypePrinter<T, true>**:当 HasValueType 为 true 时,表示 T 有非 void 的 value_type,提供相应的 print 实现。
    • **特化版本 TypePrinter<T, false>**:当 HasValueType 为 false 时,表示 T 没有 value_type 或 value_type 是 void,提供默认的 print 实现。
  3. 测试结构体:
    • WithValueType:有一个非 void 的 value_type。
    • WithoutValueType:没有 value_type。
    • WithVoidValueType:有一个 value_type,但它是 void。
  4. main 函数:
    • 分别测试了三种情况,验证 TypePrinter 的行为是否符合预期。

五、SFINAE的优缺点

优点:

  1. 灵活性高:能够根据类型特性选择不同的实现,提升代码的泛化能力。
  2. 类型安全:通过编译期检测,避免了运行时错误。
  3. 无需额外的运行时开销:所有的类型筛选都在编译期完成。

缺点:

  1. 复杂性高:SFINAE相关的代码往往较为复杂,阅读和维护难度较大。
  2. 编译器错误信息难以理解:SFINAE失败时,编译器可能给出晦涩的错误信息,调试困难。
  3. 模板实例化深度限制:过度使用SFINAE可能导致编译时间增加和模板实例化深度限制问题。

六、现代C++中的替代方案

随着C++11及后续标准的发展,引入了诸如decltype、constexpr、if constexpr、概念(C++20)等新的特性,部分情况下可以替代传统的SFINAE,提高代码的可读性和可维护性。例如,C++20引入的概念(Concepts)提供了更为简洁和直观的方式来约束模板参数,减少了SFINAE的复杂性。

示例:使用概念替代SFINAE

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
#include <concepts>
#include <iostream>

// 定义一个概念,要求类型 T 是整数类型
template <typename T>
concept Integral = std::is_integral_v<T>;

// 仅当 T 满足 Integral 概念时启用
template <Integral T>
void print_type(T value) {
std::cout << "Integral type: " << value << std::endl;
}

int main() {
print_type(42); // 输出: Integral type: 42
// print_type(3.14); // 编译错误,不满足 Integral 概念
return 0;
}

解释:

  • 使用概念Integral代替std::enable_if,语法更简洁,代码更易读。
  • 当类型不满足概念时,编译器会给出明确的错误信息,便于调试。

虽然上述方法经典且有效,但在C++11及以后版本,存在更简洁和易读的方式来实现相同的功能。例如,使用std::void_t和更现代的检测技巧,或者直接使用C++20的概念(Concepts),使代码更加清晰。

示例:使用std::void_t简化has_foo

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
#include <type_traits>
#include <iostream>

// 使用 std::void_t 简化 has_foo
template <typename, typename = std::void_t<>>
struct has_foo : std::false_type {};

template <typename T>
struct has_foo<T, std::void_t<decltype(std::declval<T>().foo())>> : std::true_type {};

// 函数仅在 T 有 foo() 成员时启用
template <typename T>
std::enable_if_t<has_foo<T>::value, void>
call_foo(T& obj) {
obj.foo();
std::cout << "foo() called." << std::endl;
}

class WithFoo {
public:
void foo() { std::cout << "WithFoo::foo()" << std::endl; }
};

class WithoutFoo {};

int main() {
WithFoo wf;
call_foo(wf); // 输出: WithFoo::foo()
// foo() called.

// WithoutFoo wf2;
// call_foo(wf2); // 编译错误,没有匹配的函数
return 0;
}

解释:

  • 利用std::void_t,has_foo结构更为简洁。
  • decltype(std::declval<T>().foo())尝试在不实例化T对象的情况下检测foo()成员函数。
  • 如果foo()存在,has_foo<T>继承自std::true_type,否则继承自std::false_type。

使用C++20概念

如果你使用的是支持C++20的编译器,可以利用概念(Concepts)进一步简化和增强可读性。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
#include <concepts>
#include <type_traits>
#include <iostream>

// 定义一个概念,要求类型 T 具有 void foo()
template <typename T>
concept HasFoo = requires(T t) {
{ t.foo() } -> std::same_as<void>;
};

// 仅当 T 满足 HasFoo 概念时启用
template <HasFoo T>
void call_foo(T& obj) {
obj.foo();
std::cout << "foo() called." << std::endl;
}

class WithFoo {
public:
void foo() { std::cout << "WithFoo::foo()" << std::endl; }
};

class WithoutFoo {};

int main() {
WithFoo wf;
call_foo(wf); // 输出: WithFoo::foo()
// foo() called.

// WithoutFoo wf2;
// call_foo(wf2); // 编译错误,不满足 HasFoo 概念
return 0;
}

解释:

  • HasFoo概念:使用requires表达式检测类型T是否具有void foo()成员函数。
  • call_foo函数模板:仅当T满足HasFoo概念时,模板被启用。
  • 这种方式更直观,易于理解和维护。

七、总结

SFINAE作为C++模板编程中的一项强大功能,通过在模板实例化过程中允许替换失败而不报错,实现了基于类型特性的编程。然而,SFINAE的语法复杂且难以维护,现代C++引入的新特性如概念等在某些情况下已经能够更简洁地实现类似的功能。尽管如此,理解SFINAE的工作机制依然对于掌握高级模板技术和阅读老旧代码具有重要意义。


综合案例:结合模板特化与折叠表达式

为了进一步巩固对模板特化和折叠表达式的理解,本节将通过一个综合案例展示如何将两者结合使用。

案例描述

实现一个通用的日志记录器Logger,能够处理任意数量和类型的参数,并根据不同的类型组合调整输出格式。具体需求包括:

  1. 对于普通类型,使用通用的打印格式。
  2. 对于指针类型,打印指针地址或指向的值。
  3. 对于std::string类型,使用专门的格式。
  4. 支持可变数量的参数,通过折叠表达式实现参数的逐一打印。

实现步骤

  1. **定义通用类模板Logger**,使用模板特化和偏特化处理不同类型。
  2. 实现log函数,使用模板折叠表达式逐一打印参数。

代码实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
#include <iostream>
#include <string>
#include <type_traits>

// 通用类模板
template <typename T, typename Enable = void>
class Logger {
public:
static void log(const T& value) {
std::cout << "General Logger: " << value << std::endl;
}
};

// 类模板偏特化:当 T 是指针类型
template <typename T>
class Logger<T, typename std::enable_if<std::is_pointer<T>::value>::type> {
public:
static void log(T value) {
if (value) {
std::cout << "Pointer Logger: " << *value << std::endl;
} else {
std::cout << "Pointer Logger: nullptr" << std::endl;
}
}
};

// 类模板全特化:当 T 是 std::string
template <>
class Logger<std::string> {
public:
static void log(const std::string& value) {
std::cout << "String Logger: \"" << value << "\"" << std::endl;
}
};

// 函数模板,用于递归调用 Logger::log
template <typename T>
void logOne(const T& value) {
Logger<T>::log(value);
}

// 使用模板折叠表达式实现多参数日志记录
template <typename... Args>
void logAll(const Args&... args) {
(logOne(args), ...); // 左折叠,调用 logOne 对每个参数进行日志记录
}

int main() {
int a = 10;
double b = 3.14;
std::string s = "Hello, World!";
int* ptr = &a;
double* pNull = nullptr;

// 使用 Logger 类模板进行特化打印
Logger<int>::log(a); // 输出:General Logger: 10
Logger<double*>::log(pNull); // 输出:Pointer Logger: nullptr
Logger<std::string>::log(s); // 输出:String Logger: "Hello, World!"

std::cout << "\nLogging multiple parameters:" << std::endl;
logAll(a, b, s, ptr, pNull);
/*
输出:
General Logger: 10
General Logger: 3.14
String Logger: "Hello, World!"
Pointer Logger: 10
Pointer Logger: nullptr
*/

return 0;
}

输出:

1
2
3
4
5
6
7
8
9
10
General Logger: 10
Pointer Logger: nullptr
String Logger: "Hello, World!"

Logging multiple parameters:
General Logger: 10
General Logger: 3.14
String Logger: "Hello, World!"
Pointer Logger: 10
Pointer Logger: nullptr

解析

  1. **通用模板Logger<T, Enable>**:
    • 使用第二个模板参数Enable与SFINAE(Substitution Failure Is Not An Error)结合,控制模板特化。
    • 对于非指针类型和非std::string类型,使用通用实现,打印"General Logger: value"。
  2. **类模板偏特化Logger<T, Enable>**:
    • 使用std::enable_if和std::is_pointer,当T是指针类型时,特化模板。
    • 实现指针类型的特殊日志处理,打印指针指向的值或nullptr。
  3. **类模板全特化Logger<std::string>**:
    • 为std::string类型提供全特化版本,使用不同的输出格式。
  4. logOne函数模板:
    • 简化调用过程,调用相应的Logger<T>::log方法。
  5. logAll函数模板:
    • 使用模板折叠表达式(logOne(args), ...),实现对所有参数的逐一日志记录。
    • 通过左折叠的逗号表达式,确保每个logOne调用依次执行。
  6. main函数:
    • 测试不同类型的日志记录,包括普通类型、指针类型和std::string类型。
    • 调用logAll函数,实现多参数的综合日志记录。

模板元编程(Template Metaprogramming)

  • 什么是模板元编程:模板元编程(Template Metaprogramming)是一种在编译期通过模板机制进行代码生成和计算的编程技术。它利用编译器的模板实例化机制,在编译期间执行代码逻辑,以提高程序的性能和灵活性。
  • 模板元编程的优势:
    • 提高代码的可重用性和泛化能力。
    • 在编译期进行复杂计算,减少运行时开销。
    • 实现类型安全的高级抽象。

模板元编程基础

  • 模板特化(Template Specialization):
    • 全特化(Full Specialization):为特定类型提供特定实现。
    • 偏特化(Partial Specialization):为部分模板参数特定的情况提供实现。
  • 递归模板(Recursive Templates):利用模板的递归实例化机制,实现编译期计算。

编译期计算

模板元编程允许在编译时执行计算,如计算阶乘、斐波那契数列等。

示例:编译期阶乘

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
#include <iostream>

// 基础情况
template <int N>
struct Factorial {
static const int value = N * Factorial<N - 1>::value;
};

// 递归终止
template <>
struct Factorial<0> {
static const int value = 1;
};

int main() {
std::cout << "5! = " << Factorial<5>::value << std::endl; // 输出:5! = 120
std::cout << "0! = " << Factorial<0>::value << std::endl; // 输出:0! = 1
return 0;
}

输出:

1
2
5! = 120
0! = 1

讲解:

  1. 基本模板 Factorial定义了一个静态常量value,其值为N * Factorial<N - 1>::value,实现递归计算。
  2. 特化模板 Factorial<0>定义递归终止条件,当N=0时,value为1。
  3. 在main函数中,通过Factorial<5>::value获取5的阶乘结果,编译期即生成其值。

静态成员变量的基本规则

在 C++ 中,静态成员变量的声明与定义有以下基本规则:

  1. 声明(Declaration):在类内部声明静态成员变量,告诉编译器该类包含这个静态成员。
  2. 定义(Definition):在类外部对静态成员变量进行定义,分配存储空间。

通常,对于非 constexpr 或非 inline 的静态成员变量,必须 在类外进行定义,否则会导致链接器错误(undefined reference)。

特殊情况:static const 整数成员

对于 static const 整数类型 的静态成员变量,C++ 标准做了一些特殊的处理:

  • 类内初始化:你可以在类内部初始化 static const 整数成员变量,例如 static const int value = 42;。

  • 使用场景

    :

    • 不需要类外定义:在某些情况下,编译器在编译阶段可以直接使用类内的初始化值,无需类外定义。
    • 需要类外定义:如果你在程序中对该静态成员变量进行取址(例如,&Factorial<5>::value),或者在其他需要该变量的存储位置时,就需要在类外进行定义。

C++11 及之前的标准

在 C++11 及更早的标准中,对于 static const 整数成员变量:

  • 不需要类外定义的情况

    :

    • 仅在作为编译期常量使用时,不需要类外定义。例如,用于数组大小、模板参数等。
  • 需要类外定义的情况

    :

    • 当你需要对变量进行取址,或者在需要其存储位置时,必须在类外定义。例如:

      1
      2
      3
      4
      5
      6
      7
      8
      template<int N>
      struct Factorial{
      static const int value = N * Factorial<N-1>::value;
      };

      // 类外定义
      template<int N>
      const int Factorial<N>::value;

C++17 及更新标准

从 C++17 开始,引入了 内联变量(inline variables),使得在类内定义静态成员变量变得更加灵活:

  • 内联静态成员变量

    :

    • 使用 inline 关键字,可以在类内对静态成员变量进行定义,无需在类外进行单独定义。
    • 这适用于 C++17 及更高版本。

例如,你可以这样编写:

1
2
3
4
5
6
7
8
9
template<int N>
struct Factorial{
inline static const int value = N * Factorial<N-1>::value;
};

template<>
struct Factorial<0>{
inline static const int value = 1;
};

在这种情况下,无需在类外进行定义,因为 inline 确保了该变量在每个翻译单元中都只有一个实例。

在 C++11 及之前的标准

代码:

1
2
3
4
5
6
7
8
9
template<int N>
struct Factorial{
static const int value = N * Factorial<N-1>::value;
};

template<>
struct Factorial<0>{
static const int value = 1;
};
  • 作为编译期常量使用

    :

    • 例如,用于其他模板参数或编译期常量计算时,不需要类外定义。
  • 取址或需要存储位置时

    :

    • 需要在类外进行定义。例如:

      1
      2
      3
      4
      5
      template<int N>
      const int Factorial<N>::value;

      template<>
      const int Factorial<0>::value;

在 C++17 及更高标准

如果你使用 C++17 及更高版本,可以使用 inline 关键字:

1
2
3
4
5
6
7
8
9
template<int N>
struct Factorial{
inline static const int value = N * Factorial<N-1>::value;
};

template<>
struct Factorial<0>{
inline static const int value = 1;
};
  • 无需类外定义

    :

    • inline 使得在类内的定义成为唯一的定义,即使在多个翻译单元中使用,也不会导致重复定义错误。

实际示例与测试

示例 1:仅作为编译期常量使用

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
#include <iostream>

// 你的 Factorial 模板
template<int N>
struct Factorial{
static const int value = N * Factorial<N-1>::value;
};

template<>
struct Factorial<0>{
static const int value = 1;
};

int main() {
std::cout << "Factorial<5>::value = " << Factorial<5>::value << std::endl;
return 0;
}
  • C++11 及之前:无需类外定义。
  • C++17 及更新:同样无需类外定义,且可以使用 inline 进一步优化。

示例 2:取址

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
#include <iostream>

// 你的 Factorial 模板
template<int N>
struct Factorial{
static const int value = N * Factorial<N-1>::value;
};

template<>
struct Factorial<0>{
static const int value = 1;
};

// 类外定义(在 C++11 及之前需要)
template<int N>
const int Factorial<N>::value;

template<>
const int Factorial<0>::value;

int main() {
std::cout << "Factorial<5>::value = " << Factorial<5>::value << std::endl;
std::cout << "&Factorial<5>::value = " << &Factorial<5>::value << std::endl;
return 0;
}
  • C++11 及之前:必须提供类外定义,否则会在链接时出现错误。
  • C++17 及更新:若未使用 inline,仍需提供类外定义;使用 inline 则无需。

示例 3:使用 inline(C++17 及更高)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
#include <iostream>

// 你的 Factorial 模板(使用 inline)
template<int N>
struct Factorial{
inline static const int value = N * Factorial<N-1>::value;
};

template<>
struct Factorial<0>{
inline static const int value = 1;
};

int main() {
std::cout << "Factorial<5>::value = " << Factorial<5>::value << std::endl;
std::cout << "&Factorial<5>::value = " << &Factorial<5>::value << std::endl;
return 0;
}
  • C++17 及以上:
    • 无需类外定义。
    • inline 保证了多重定义的合法性。

详细解析

为什么有这样的特殊处理?

  • 优化与性能

    :

    • 在编译期常量的情况下,不需要在运行时分配存储空间,编译器可以优化掉相关代码。
  • 兼容性

    :

    • 早期 C++ 标准遵循这种规则,允许在类内初始化静态常量成员变量,便于模板元编程和常量表达式的使用。
  • inline 变量

    :

    • C++17 引入 inline 关键字用于变量,解决了静态成员变量在多个翻译单元中的定义问题,使得代码更简洁。

是否总是需要定义?

并非总是需要。关键在于 如何使用 这个静态成员变量:

  • 仅作为编译期常量使用:无需类外定义。
  • 需要存储位置或取址:需要类外定义,除非使用 inline(C++17 及以上)。

编译器与链接器的行为

  • 编译阶段

    :

    • 类内的初始化用于编译期常量计算,不涉及存储分配。
  • 链接阶段

    :

    • 如果没有类外定义,且静态成员被 odr-used(可能需要存储位置),链接器会报错,提示找不到符号定义。
    • 使用 inline 关键字后,编译器处理为内联变量,避免了多重定义问题。

示例:编译期斐波那契数列

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
#include <iostream>

// 基础情况
template <int N>
struct Fibonacci {
static const long long value = Fibonacci<N - 1>::value + Fibonacci<N - 2>::value;
};

// 递归终止
template <>
struct Fibonacci<0> {
static const long long value = 0;
};

template <>
struct Fibonacci<1> {
static const long long value = 1;
};

int main() {
std::cout << "Fibonacci<10> = " << Fibonacci<10>::value << std::endl; // 输出:Fibonacci<10> = 55
std::cout << "Fibonacci<20> = " << Fibonacci<20>::value << std::endl; // 输出:Fibonacci<20> = 6765
return 0;
}

输出:

1
2
Fibonacci<10> = 55
Fibonacci<20> = 6765

要点:

  • 模板元编程利用编译期计算提升程序性能。
  • 需要理解模板递归与终止条件。
  • 常与类型特性和模板特化结合使用。

类型计算与SFINAE

  • 类型计算:在编译期进行类型的推导和转换。
  • SFINAE(Substitution Failure Is Not An Error):模板实例化过程中,如果某个替换失败,编译器不会报错,而是忽略该模板,并尝试其他匹配。

示例:检测类型是否可加

1
2
3
4
5
6
7
8
9
10
11
12
#include <type_traits>

// 检测是否可以对T类型进行加法操作
template <typename T, typename = void>
struct is_addable : std::false_type {};

template <typename T>
struct is_addable<T, decltype(void(std::declval<T>() + std::declval<T>()))> : std::true_type {};

// 使用
static_assert(is_addable<int>::value, "int should be addable");
static_assert(!is_addable<void*>::value, "void* should not be addable");

讲解:

1. struct is_addable<...> : std::true_type {}

  • 目的:定义一个名为 is_addable 的结构体模板,它继承自 std::true_type。
  • 作用:当特定的模板参数满足条件时,这个特化版本将被选中,表示 T 类型是可加的,即支持 + 操作符。

2. 模板参数解释:<T, decltype(void(std::declval<T>() + std::declval<T>()))>

  • **T**:这是要检查的类型。
  • **std::declval<T>()**:
    • 用途:std::declval<T>() 是一个用于在不实际创建 T 类型对象的情况下,生成一个 T 类型的右值引用。
    • 作用:它允许我们在编译时模拟 T 类型的对象,以便用于表达式的检测。
  • **std::declval<T>() + std::declval<T>()**:
    • 表达式:尝试对两个 T 类型的右值引用进行加法运算。
    • 目的:检查 T 类型是否支持 + 操作符。
  • **void(...)**:
    • 将加法表达式的结果转换为 void 类型。这是为了在 decltype 中仅关心表达式是否有效,而不关心其具体类型。
  • **decltype(void(std::declval<T>() + std::declval<T>()))**:
    • 作用:如果 T 类型支持加法运算,则该 decltype 表达式的类型为 void,否则会导致替换失败

高级模板元编程技巧

  • 变参模板(Variadic Templates):支持模板参数包,实现更加灵活的模板定义。

示例:求和模板

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
// 基本递归模板
template <int... Ns>
struct Sum;

// 递归终止
template <>
struct Sum<> {
static const int value = 0;
};

// 递归定义
template <int N, int... Ns>
struct Sum<N, Ns...> {
static const int value = N + Sum<Ns...>::value;
};

// 使用
int main() {
int result = Sum<1, 2, 3, 4, 5>::value; // 15
return 0;
}

讲解:

  1. 基本模板 Sum接受一个整数参数包Ns...。
  2. 特化模板 Sum<>定义递归终止条件,value为0。
  3. 递归定义 Sum<N, Ns...>将第一个参数N与剩余参数的和相加。
  4. 在main函数中,通过Sum<1, 2, 3, 4, 5>::value计算1+2+3+4+5=15。
  • 类型列表(Type Lists):通过模板参数包管理类型的集合。

示例:类型列表和元素访问

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
// 定义类型列表
template <typename... Ts>
struct TypeList {};

// 获取类型列表中第N个类型
template <typename List, std::size_t N>
struct TypeAt;

template <typename Head, typename... Tail>
struct TypeAt<TypeList<Head, Tail...>, 0> {
using type = Head;
};

template <typename Head, typename... Tail, std::size_t N>
struct TypeAt<TypeList<Head, Tail...>, N> {
using type = typename TypeAt<TypeList<Tail...>, N - 1>::type;
};

// 使用
using list = TypeList<int, double, char>;
using third_type = TypeAt<list, 2>::type; // char

讲解:

  1. **TypeList**:定义一个包含多个类型的类型列表。
  2. TypeAt:通过递归模板,从TypeList中获取第N个类型。
    • 当N为0时,类型为Head。
    • 否则,递归获取Tail...中第N-1个类型。
  3. 使用:定义list为TypeList<int, double, char>,third_type为第2个类型,即char。

实际应用案例

案例1:静态断言与类型检查

1
2
3
4
5
6
7
8
9
10
#include <type_traits>

template <typename T>
struct is_integral_type {
static const bool value = std::is_integral<T>::value;
};

// 使用
static_assert(is_integral_type<int>::value, "int is integral");
static_assert(!is_integral_type<float>::value, "float is not integral");

案例2:编译期字符串

1
2
3
4
5
6
7
8
9
10
11
12
13
#include <utility>

// 编译期字符串
template <char... Cs>
struct String {
static constexpr char value[sizeof...(Cs) + 1] = { Cs..., '\0' };
};

template <char... Cs>
constexpr char String<Cs...>::value[sizeof...(Cs) + 1];

// 使用
using hello = String<'H','e','l','l','o'>;

为什么需要外部定义 value

在 C++ 中,静态成员变量与类的实例无关,它们存在于全局命名空间中。然而,静态成员变量的声明和定义是不同的:

  1. 声明:告诉编译器类中存在这个变量。
  2. 定义:为这个变量分配存储空间。

对于非 inline 的静态成员变量,即使是 constexpr,都需要在类外部进行定义。否则,链接器在处理多个翻译单元时会因为找不到变量的定义而报错。

具体原因

  1. 模板类的静态成员变量:
    • 每当模板实例化时,都会产生一个新的类类型,每个类类型都有自己的一组静态成员变量。
    • 因此,编译器需要知道这些静态成员变量在所有翻译单元中都唯一对应一个定义。
  2. constexpr 静态成员变量:
    • 从 C++17 开始,inline 关键字引入,使得 constexpr 静态成员变量可以在类内定义,并且隐式地具有 inline 属性。这意味着不需要在类外定义它们,因为 inline 确保了在多个翻译单元中有同一份定义。
    • 但在 C++17 之前或不使用 inline 的情况下,即使是 constexpr,仍需在类外定义。
  • 类内声明:static constexpr char value[...] 声明了 value 并给予了初始值。
  • 类外定义:constexpr char String<Cs...>::value[...] 为 value 分配了存储空间。

如果省略类外定义,编译器会在链接阶段找不到 value 的定义,导致链接错误。这尤其适用于 C++14 及更早版本,以及 C++17 中未使用 inline 的情形。

如何避免外部定义

如果你使用的是 C++17 或更高版本,可以通过 inline 关键字将静态成员变量声明为 inline,从而在类内完成定义,无需再在外部定义。例如:

1
2
3
4
5
6
7
8
9
10
#include <utility>

// 编译期字符串
template <char... Cs>
struct String {
inline static constexpr char value[sizeof...(Cs) + 1] = { Cs..., '\0' };
};

// 使用
using hello = String<'H','e','l','l','o'>;

在这个版本中,inline 关键字告诉编译器这是一个内联变量,允许在多个翻译单元中存在同一份定义,而不会导致重复定义错误。因此,无需在类外再次定义 value。

完整示例对比

不使用 inline(需要类外定义)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
#include <utility>

// 编译期字符串
template <char... Cs>
struct String {
static constexpr char value[sizeof...(Cs) + 1] = { Cs..., '\0' };
};

// 外部定义
template <char... Cs>
constexpr char String<Cs...>::value[sizeof...(Cs) + 1];

// 使用
using hello = String<'H','e','l','l','o'>;

int main() {
// 访问 value
// std::cout << hello::value;
}

使用 inline(无需类外定义,C++17 起)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
#include <utility>

// 编译期字符串
template <char... Cs>
struct String {
inline static constexpr char value[sizeof...(Cs) + 1] = { Cs..., '\0' };
};

// 使用
using hello = String<'H','e','l','l','o'>;

int main() {
// 访问 value
// std::cout << hello::value;
}

C++20 Concepts

C++20 引入了 Concepts,它们为模板参数提供了更强的约束和表达能力,使模板的使用更简洁、错误信息更友好。

定义与使用

定义一个 Concept

Concepts 使用 concept 关键字定义,并作为函数或类模板的约束。

1
2
3
4
5
6
7
8
#include <concepts>
#include <iostream>

// 定义一个 Concept:要求类型必须是可输出到 std::ostream
template <typename T>
concept Printable = requires(T a) {
{ std::cout << a } -> std::same_as<std::ostream&>;
};

使用 Concept 约束模板

1
2
3
4
5
6
7
8
9
10
11
12
// 使用 Concepts 约束函数模板
template <Printable T>
void print(const T& value) {
std::cout << value << std::endl;
}

int main() {
print(42); // 正常调用
print("Hello"); // 正常调用
// print(std::vector<int>{1, 2, 3}); // 编译错误,std::vector<int> 不满足 Printable
return 0;
}

限制与约束

Concepts 允许为模板参数定义复杂的约束,使得模板更具表达性,同时提升编译器错误信息的可理解性。

示例:排序函数中的 Concepts

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
#include <concepts>
#include <vector>
#include <iostream>
#include <algorithm>

// 定义一个可比较的概念
template <typename T>
concept Comparable = requires(T a, T b) {
{ a < b } -> std::convertible_to<bool>;
};

// 排序函数,约束类型必须可比较
template <Comparable T>
void sortVector(std::vector<T>& vec) {
std::sort(vec.begin(), vec.end());
}

int main() {
std::vector<int> nums = {4, 2, 3, 1};
sortVector(nums);
for(auto num : nums)
std::cout << num << " "; // 输出:1 2 3 4
std::cout << std::endl;

// std::vector<std::vector<int>> vecs;
// sortVector(vecs); // 编译错误,std::vector<int> 不满足 Comparable
return 0;
}

输出:

1
1 2 3 4 

要点:

  • Concepts 提供了模板参数的语义约束。
  • 使用 Concepts 提高模板的可读性和可维护性。
  • 生成更友好的编译错误信息,易于调试。

模板实例化与编译器行为

理解模板实例化的过程有助于进行有效的模板设计与优化,尤其是在涉及链接和编译时间时。

显式实例化(Explicit Instantiation)

显式实例化告诉编译器生成特定类型下的模板代码,主要用于分离模板的声明与定义,减少编译时间。

语法:

1
2
3
4
5
6
7
8
9
10
// 声明模板(通常在头文件中)
template <typename T>
class MyClass;

// 定义模板(通常在源文件中)
template <typename T>
class MyClass { /* ... */ };

// 显式实例化
template class MyClass<int>;

示例:分离类模板的声明与定义

MyClass.h

1
2
3
4
5
6
7
8
9
10
#ifndef MYCLASS_H
#define MYCLASS_H

template <typename T>
class MyClass {
public:
void doSomething();
};

#endif // MYCLASS_H

MyClass.cpp

1
2
3
4
5
6
7
8
9
10
11
#include "MyClass.h"
#include <iostream>

template <typename T>
void MyClass<T>::doSomething() {
std::cout << "Doing something with " << typeid(T).name() << std::endl;
}

// 显式实例化
template class MyClass<int>;
template class MyClass<double>;

main.cpp

1
2
3
4
5
6
7
8
9
10
11
12
13
#include "MyClass.h"

int main() {
MyClass<int> obj1;
obj1.doSomething(); // 输出:Doing something with i

MyClass<double> obj2;
obj2.doSomething(); // 输出:Doing something with d

// MyClass<char> obj3; // 链接错误,因为 MyClass<char> 未实例化

return 0;
}

输出:

1
2
Doing something with i
Doing something with d

注意事项:

  • 显式实例化需要在模板定义后进行。
  • 只有显式实例化的类型在未实例化时可用于模板分离。
  • 未显式实例化的类型可能导致链接错误。

隐式实例化(Implicit Instantiation)

隐式实例化是编译器在模板被实际使用时自动生成对应实例代码的过程。通常,模板定义与使用都在头文件中完成。

示例:

MyClass.h

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
#ifndef MYCLASS_H
#define MYCLASS_H

#include <iostream>
#include <typeinfo>

template <typename T>
class MyClass {
public:
void doSomething() {
std::cout << "Doing something with " << typeid(T).name() << std::endl;
}
};

#endif // MYCLASS_H

main.cpp

1
2
3
4
5
6
7
8
9
10
11
12
13
14
#include "MyClass.h"

int main() {
MyClass<int> obj1;
obj1.doSomething(); // 输出:Doing something with i

MyClass<double> obj2;
obj2.doSomething(); // 输出:Doing something with d

MyClass<char> obj3;
obj3.doSomething(); // 输出:Doing something with c

return 0;
}

输出:

1
2
3
Doing something with i
Doing something with d
Doing something with c

要点:

  • 隐式实例化不需要显式声明或定义。
  • 模板定义必须在使用前可见,通常通过头文件实现。
  • 容易导致编译时间增加,尤其是大型模板库。

链接时问题与解决方案

由于模板是在使用时被实例化,跨源文件使用模板可能导致链接时问题,如重复定义或未定义引用。

解决方案:

  1. 内联实现:将模板的定义与声明一起放在头文件中,避免链接时重复定义。
  2. 显式实例化:将常用的模板实例化放在源文件中,其他源文件通过 extern 或头文件引用已有实例。
  3. **使用 extern template**:告知编译器某些模板实例已在其他源文件中显式实例化。

示例:使用 extern template

MyClass.h

1
2
3
4
5
6
7
8
9
10
11
12
13
14
#ifndef MYCLASS_H
#define MYCLASS_H

template <typename T>
class MyClass {
public:
void doSomething();
};

// 声明模板实例,但不定义
extern template class MyClass<int>;
extern template class MyClass<double>;

#endif // MYCLASS_H

MyClass.cpp

1
2
3
4
5
6
7
8
9
10
11
12
#include "MyClass.h"
#include <iostream>
#include <typeinfo>

template <typename T>
void MyClass<T>::doSomething() {
std::cout << "Doing something with " << typeid(T).name() << std::endl;
}

// 显式实例化
template class MyClass<int>;
template class MyClass<double>;

main.cpp

1
2
3
4
5
6
7
8
9
10
11
12
#include "MyClass.h"

int main() {
MyClass<int> obj1;
obj1.doSomething(); // 使用已显式实例化的模板

MyClass<double> obj2;
obj2.doSomething(); // 使用已显式实例化的模板

// MyClass<char> obj3; // 链接错误,未实例化
return 0;
}

要点:

  • 使用 extern template 声明已在其他源文件中实例化的模板。
  • 减少编译时间和链接大小,防止重复定义。

最佳实践与注意事项

掌握模板的最佳实践有助于编写高效、可维护的代码,同时避免常见的陷阱和问题。

模板定义与实现分离

对于类模板,通常将模板的声明和定义放在同一头文件中,以确保编译器在实例化模板时能够看到完整的定义。尽管可以尝试将模板定义分离到源文件,但需要结合显式实例化,这会增加复杂性,且不适用于广泛使用的模板。

推荐做法:

  • 类模板:将声明和实现统一在头文件中。
  • 函数模板:同样将声明和实现统一在头文件中,或使用显式实例化。

避免过度模板化

虽然模板提供了极大的灵活性,但过度复杂的模板会导致代码难以理解、维护和编译时间增加。

建议:

  • 只在必要时使用模板。
  • 保持模板的简单性和可读性,避免过度嵌套和复杂的特化。
  • 合理使用类型特性和 Concepts 进行约束。

提高编译速度的方法

模板的广泛使用可能导致编译时间显著增加。以下方法有助于优化编译速度:

  1. 预编译头文件(Precompiled Headers):将频繁使用的模板库放入预编译头中,加速编译。
  2. 显式实例化:通过显式实例化减少模板的重复编译。
  3. 模块化编程(C++20 Modules):利用模块化将模板库进行编译和链接,减少编译时间。
  4. 合理分割头文件:避免头文件中的模板定义过大,分割成较小的模块。

代码复用与库设计

模板是实现高度复用库组件的有效手段,如标准库(std::vector、std::map 等)广泛使用模板。设计模板库时,需考虑以下因素:

  • 接口的一致性:保持模板库的接口简洁、一致,便于使用者理解和使用。
  • 文档与示例:提供详细的文档和示例代码,帮助使用者理解模板库的用法。
  • 错误信息友好:通过 Concepts、SFINAE 等机制提供清晰的错误信息,降低使用门槛。
  • 性能优化:利用模板的编译期计算和内联等特性,提高库组件的性能。

避免模板错误的困惑

模板错误通常复杂且难以理解,以下方法有助于减少模板错误的困惑:

  • 逐步调试:从简单的模板开始,逐步增加复杂性,便于定位错误。
  • 使用编译器警告与工具:开启编译器的警告选项,使用静态分析工具检测模板代码中的问题。
  • 代码注释与文档:详细注释复杂的模板代码,提供文档说明其设计和用途。

总结

C++ 模板机制是实现泛型编程的核心工具,通过类型参数化和编译期计算,极大地提升了代码的复用性、灵活性和性能。从基础的函数模板和类模板,到高级的模板特化、变参模板、模板元编程、SFINAE 和 Concepts,掌握模板的各个方面能够帮助开发者编写更高效、更加通用的 C++ 代码。

在实际应用中,合理运用模板不仅可以简化代码结构,还可以提高代码的可维护性和可扩展性。然而,模板的复杂性也要求开发者具备扎实的 C++ 基础和良好的编程习惯,以避免过度复杂化和难以调试的问题。

通过本教案的系统学习,相信您已经具备了全面理解和运用 C++ 模板的能力,能够在实际项目中高效地利用模板特性,编写出更为优秀的代码。


练习与习题

练习 1:实现一个通用的 Swap 函数模板

要求:

  • 编写一个函数模板 swapValues,可以交换任意类型的两个变量。
  • 在 main 函数中测试 int、double、std::string 类型的交换。

提示:

1
2
3
4
5
6
template <typename T>
void swapValues(T& a, T& b) {
T temp = a;
a = b;
b = temp;
}

练习 2:实现一个模板类 Triple,存储三个相同类型的值,并提供获取各个成员的函数。

要求:

  • 模板参数为类型 T。
  • 提供构造函数、成员变量及访问函数。
  • 在 main 中实例化 Triple<int> 和 Triple<std::string>,进行测试。

练习 3:使用模板特化,为类模板 Printer 提供针对 bool 类型的全特化,实现专门的输出格式。

要求:

  • 通用模板类 Printer,具有 print 函数,输出 General Printer: value。
  • 全特化 Printer<bool>,输出 Boolean Printer: true 或 Boolean Printer: false。

练习 4:实现一个变参模板函数 logMessages,可以接受任意数量和类型的参数,并依次打印它们。

要求:

  • 使用递归方法实现。
  • 在 main 中测试不同参数组合的调用。

练习 5:编写模板元编程结构 IsPointer, 用于在编译期判断一个类型是否为指针类型。

要求:

  • 定义 IsPointer<T>,包含 value 静态常量成员,值为 true 或 false。
  • 使用特化进行实现。
  • 在 main 中使用 static_assert 进行测试。

示例:

1
2
static_assert(IsPointer<int*>::value, "int* is a pointer");
static_assert(!IsPointer<int>::value, "int is not a pointer");

练习 6:使用 SFINAE,编写一个函数模板 enableIfExample,只有当类型 T 具有 size() 成员函数时才启用。

要求:

  • 使用 std::enable_if 和类型特性检测 size() 成员。
  • 在 main 中测试 std::vector<int>(应启用)和 int(不应启用)。

提示:

1
2
3
4
5
template <typename T>
typename std::enable_if<has_size<T>::value, void>::type
enableIfExample(const T& container) {
std::cout << "Container has size: " << container.size() << std::endl;
}

练习 7:使用 C++20 Concepts,定义一个 Concept Integral,要求类型必须是整型,并使用该 Concept 约束一个函数模板 isEven,判断传入的整数是否为偶数。

要求:

  • 定义 Integral Concept。
  • 编写函数模板 isEven(u),仅接受满足 Integral 的类型。
  • 在 main 中测试不同类型的调用。

示例:

1
2
3
4
template <Integral T>
bool isEven(T value) {
return value % 2 == 0;
}

练习 8:实现一个固定大小的栈(FixedStack)类模板,支持多种数据类型和指定大小。使用非类型模板参数指定栈的大小。

要求:

  • 模板参数为类型 T 和 std::size_t N。
  • 提供 push, pop, top 等成员函数。
  • 在 main 中测试 FixedStack<int, 5> 和 FixedStack<std::string, 3>。

练习 9:实现一个模板类 TypeIdentity,其成员类型 type 等同于模板参数 T。并使用 static_assert 检查类型关系。

要求:

  • 定义 TypeIdentity<T>,包含类型成员 type。
  • 使用 std::is_same 与 static_assert 验证。

示例:

1
static_assert(std::is_same<TypeIdentity<int>::type, int>::value, "TypeIdentity<int> should be int");

练习 10:编写一个模板元编程结构 LengthOf, 用于在编译期计算类型列表的长度。

要求:

  • 使用 TypeList 模板定义类型列表。
  • 定义 LengthOf<TypeList<...>>::value 表示类型列表的长度。
  • 在 main 中使用 static_assert 进行测试。

提示:

1
2
3
4
5
6
7
8
9
10
template <typename... Ts>
struct TypeList {};

template <typename List>
struct LengthOf;

template <typename... Ts>
struct LengthOf<TypeList<Ts...>> {
static constexpr std::size_t value = sizeof...(Ts);
};

通过上述内容及练习,相信您已全面掌握了 C++ 模板的各个方面。从基础概念到高级技术,模板为 C++ 编程提供了强大的工具。持续练习与应用,将进一步巩固您的模板编程能力。

运算符重载详解

Posted on 2025-01-23 | In 零基础C++

运算符重载概述

运算符重载(Operator Overloading)允许开发者为自定义类型定义或重新定义运算符的行为,使得自定义类型的对象能够使用与内置类型相同的运算符进行操作。这不仅提高了代码的可读性,还增强了代码的表达能力。

为什么需要运算符重载

在面向对象编程中,我们经常需要定义自己的类来表示某些实体(如复数、向量、矩形等)。为了使这些类的对象能够与内置类型一样方便地进行操作,运算符重载显得尤为重要。例如:

  • 对于复数类,使用 + 运算符进行加法运算。
  • 对于字符串类,使用 << 运算符进行输出。
  • 对于矩阵类,使用 * 运算符进行矩阵乘法。

通过运算符重载,可以使代码更简洁、直观,类似于数学表达式。

运算符重载的规则与限制

  1. 不能改变运算符的优先级和结合性:运算符的优先级和结合性在编译阶段就确定,不能通过重载来改变。
  2. 不能创建新的运算符:仅能重载C++中已有的运算符,不能定义新的运算符。
  3. 至少有一个操作数必须是用户定义的类型:不能对两个内置类型进行运算符重载。
  4. 某些运算符不能重载:包括 .(成员选择运算符)、.*、::、?:(条件运算符)等。
  5. 重载运算符的优先级和结合性不可改变。

运算符重载的方法

在C++中,运算符可以通过成员函数或非成员函数(通常是友元函数)来重载。

成员函数方式

运算符作为类的成员函数进行重载时,左操作数是当前对象(this)。因此,对于需要修改左操作数的运算符,成员函数方式通常更直观。

语法示例:

1
2
3
4
class ClassName {
public:
ClassName operator+(const ClassName& other);
};

非成员函数方式(友元函数)

当需要对两个不同类型的对象进行运算,或者左操作数不是当前类的对象时,通常使用非成员函数方式。为了访问类的私有成员,非成员函数通常被声明为类的友元函数。

语法示例:

1
2
3
class ClassName {
friend ClassName operator+(const ClassName& lhs, const ClassName& rhs);
};

1. 算术运算符

1.1 + 运算符

作用:实现两个对象的加法操作。

示例类:Complex(复数类)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
#include <iostream>

class Complex {
private:
double real;
double imag;

public:
// 构造函数
Complex(double r = 0.0, double i = 0.0) : real(r), imag(i) {}

// 重载 + 运算符(成员函数)
Complex operator+(const Complex& other) const {
return Complex(this->real + other.real, this->imag + other.imag);
}

// 重载 << 运算符用于输出
friend std::ostream& operator<<(std::ostream& os, const Complex& c);
};

// 实现 << 运算符
std::ostream& operator<<(std::ostream& os, const Complex& c) {
os << "(" << c.real;
if (c.imag >= 0)
os << " + " << c.imag << "i)";
else
os << " - " << -c.imag << "i)";
return os;
}

// 示例
int main() {
Complex c1(3.0, 4.0);
Complex c2(1.5, -2.5);
Complex c3 = c1 + c2;
std::cout << "c1 + c2 = " << c3 << std::endl;
return 0;
}

输出:

1
c1 + c2 = (4.5 + 1.5i)

1.2 - 运算符

作用:实现两个对象的减法操作。

示例类:Complex(复数类)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
// ...(与上面类似)

// 重载 - 运算符(成员函数)
Complex operator-(const Complex& other) const {
return Complex(this->real - other.real, this->imag - other.imag);
}

// 示例
int main() {
Complex c1(5.0, 6.0);
Complex c2(2.5, -1.5);
Complex c4 = c1 - c2;
std::cout << "c1 - c2 = " << c4 << std::endl;
return 0;
}

输出:

1
c1 - c2 = (2.5 + 7.5i)

1.3 * 运算符

作用:实现对象的乘法操作。

示例类:Complex(复数类)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
// ...(与上面类似)

// 重载 * 运算符(成员函数)
Complex operator*(const Complex& other) const {
double r = this->real * other.real - this->imag * other.imag;
double i = this->real * other.imag + this->imag * other.real;
return Complex(r, i);
}

// 示例
int main() {
Complex c1(3.0, 4.0);
Complex c2(1.5, -2.5);
Complex c5 = c1 * c2;
std::cout << "c1 * c2 = " << c5 << std::endl;
return 0;
}

输出:

1
c1 * c2 = (13.5 + 1i)

1.4 / 运算符

作用:实现对象的除法操作。

示例类:Complex(复数类)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
#include <stdexcept>

// ...(与上面类似)

// 重载 / 运算符(成员函数)
Complex operator/(const Complex& other) const {
double denominator = other.real * other.real + other.imag * other.imag;
if (denominator == 0) {
throw std::invalid_argument("除数为零!");
}
double r = (this->real * other.real + this->imag * other.imag) / denominator;
double i = (this->imag * other.real - this->real * other.imag) / denominator;
return Complex(r, i);
}

// 示例
int main() {
Complex c1(3.0, 4.0);
Complex c2(1.5, -2.5);
try {
Complex c6 = c1 / c2;
std::cout << "c1 / c2 = " << c6 << std::endl;
} catch (const std::invalid_argument& e) {
std::cerr << "错误: " << e.what() << std::endl;
}
return 0;
}

输出:

1
c1 / c2 = (-0.823529 + 1.64706i)

2. 赋值运算符

2.1 = 运算符

作用:实现对象的赋值操作。

示例类:Complex(复数类)

C++编译器会自动生成默认的拷贝赋值运算符,但当类中包含动态分配内存或需要自定义行为时,需要自行重载。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
#include <iostream>

class Complex {
private:
double real;
double imag;

public:
// 构造函数
Complex(double r = 0.0, double i = 0.0) : real(r), imag(i) {}

// 拷贝赋值运算符(成员函数)
Complex& operator=(const Complex& other) {
if (this == &other)
return *this; // 防止自赋值
this->real = other.real;
this->imag = other.imag;
return *this;
}

// 重载 << 运算符用于输出
friend std::ostream& operator<<(std::ostream& os, const Complex& c);
};

// 实现 << 运算符
std::ostream& operator<<(std::ostream& os, const Complex& c) {
os << "(" << c.real;
if (c.imag >= 0)
os << " + " << c.imag << "i)";
else
os << " - " << -c.imag << "i)";
return os;
}

// 示例
int main() {
Complex c1(3.0, 4.0);
Complex c2;
c2 = c1; // 使用拷贝赋值运算符
std::cout << "c2 = " << c2 << std::endl;
return 0;
}

输出:

1
c2 = (3 + 4i)

2.2 复合赋值运算符(+=, -=, *=, /=)

作用:实现复合赋值操作,如 +=,-= 等。

示例类:Complex(复数类)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
// ...(与上面类似)

// 重载 += 运算符(成员函数)
Complex& operator+=(const Complex& other) {
this->real += other.real;
this->imag += other.imag;
return *this;
}

// 重载 -= 运算符(成员函数)
Complex& operator-=(const Complex& other) {
this->real -= other.real;
this->imag -= other.imag;
return *this;
}

// 重载 *= 运算符(成员函数)
Complex& operator*=(const Complex& other) {
double r = this->real * other.real - this->imag * other.imag;
double i = this->real * other.imag + this->imag * other.real;
this->real = r;
this->imag = i;
return *this;
}

// 重载 /= 运算符(成员函数)
Complex& operator/=(const Complex& other) {
double denominator = other.real * other.real + other.imag * other.imag;
if (denominator == 0) {
throw std::invalid_argument("除数为零!");
}
double r = (this->real * other.real + this->imag * other.imag) / denominator;
double i = (this->imag * other.real - this->real * other.imag) / denominator;
this->real = r;
this->imag = i;
return *this;
}

// 示例
int main() {
Complex c1(3.0, 4.0);
Complex c2(1.0, 2.0);

c1 += c2;
std::cout << "c1 += c2: " << c1 << std::endl;

c1 -= c2;
std::cout << "c1 -= c2: " << c1 << std::endl;

c1 *= c2;
std::cout << "c1 *= c2: " << c1 << std::endl;

try {
c1 /= c2;
std::cout << "c1 /= c2: " << c1 << std::endl;
} catch (const std::invalid_argument& e) {
std::cerr << "错误: " << e.what() << std::endl;
}

return 0;
}

输出:

1
2
3
4
c1 += c2: (4 + 6i)
c1 -= c2: (3 + 4i)
c1 *= c2: (-5 + 10i)
c1 /= c2: (2 + 0i)

3. 比较运算符

3.1 == 运算符

作用:判断两个对象是否相等。

示例类:Complex(复数类)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
// ...(与上面类似)

// 重载 == 运算符(友元函数)
friend bool operator==(const Complex& lhs, const Complex& rhs);

// 实现 == 运算符
bool operator==(const Complex& lhs, const Complex& rhs) {
return (lhs.real == rhs.real) && (lhs.imag == rhs.imag);
}

// 示例
int main() {
Complex c1(3.0, 4.0);
Complex c2(3.0, 4.0);
Complex c3(1.5, -2.5);

if (c1 == c2)
std::cout << "c1 和 c2 相等" << std::endl;
else
std::cout << "c1 和 c2 不相等" << std::endl;

if (c1 == c3)
std::cout << "c1 和 c3 相等" << std::endl;
else
std::cout << "c1 和 c3 不相等" << std::endl;

return 0;
}

输出:

1
2
c1 和 c2 相等
c1 和 c3 不相等

3.2 != 运算符

作用:判断两个对象是否不相等。

示例类:Complex(复数类)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
// ...(与上面类似)

// 重载 != 运算符(友元函数)
friend bool operator!=(const Complex& lhs, const Complex& rhs);

// 实现 != 运算符
bool operator!=(const Complex& lhs, const Complex& rhs) {
return !(lhs == rhs);
}

// 示例
int main() {
Complex c1(3.0, 4.0);
Complex c2(3.0, 4.0);
Complex c3(1.5, -2.5);

if (c1 != c2)
std::cout << "c1 和 c2 不相等" << std::endl;
else
std::cout << "c1 和 c2 相等" << std::endl;

if (c1 != c3)
std::cout << "c1 和 c3 不相等" << std::endl;
else
std::cout << "c1 和 c3 相等" << std::endl;

return 0;
}

输出:

1
2
c1 和 c2 相等
c1 和 c3 不相等

3.3 <, >, <=, >= 运算符

作用:实现对象之间的大小比较。对于复数来说,通常没有自然的大小顺序,但为了示例,可以定义复数的模长进行比较。

示例类:Complex(复数类)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
#include <cmath>

// ...(与上面类似)

// 重载 < 运算符(友元函数)
friend bool operator<(const Complex& lhs, const Complex& rhs);

// 重载 > 运算符(友元函数)
friend bool operator>(const Complex& lhs, const Complex& rhs);

// 重载 <= 运算符(友元函数)
friend bool operator<=(const Complex& lhs, const Complex& rhs);

// 重载 >= 运算符(友元函数)
friend bool operator>=(const Complex& lhs, const Complex& rhs);

// 实现 < 运算符
bool operator<(const Complex& lhs, const Complex& rhs) {
double lhs_mod = std::sqrt(lhs.real * lhs.real + lhs.imag * lhs.imag);
double rhs_mod = std::sqrt(rhs.real * rhs.real + rhs.imag * rhs.imag);
return lhs_mod < rhs_mod;
}

// 实现 > 运算符
bool operator>(const Complex& lhs, const Complex& rhs) {
return rhs < lhs;
}

// 实现 <= 运算符
bool operator<=(const Complex& lhs, const Complex& rhs) {
return !(rhs < lhs);
}

// 实现 >= 运算符
bool operator>=(const Complex& lhs, const Complex& rhs) {
return !(lhs < rhs);
}

// 示例
int main() {
Complex c1(3.0, 4.0); // 模长 5
Complex c2(1.0, 2.0); // 模长 sqrt(5) ≈ 2.236
Complex c3(3.0, 4.0); // 模长 5

if (c1 < c2)
std::cout << "c1 的模长小于 c2 的模长" << std::endl;
else
std::cout << "c1 的模长不小于 c2 的模长" << std::endl;

if (c1 > c2)
std::cout << "c1 的模长大于 c2 的模长" << std::endl;
else
std::cout << "c1 的模长不大于 c2 的模长" << std::endl;

if (c1 <= c3)
std::cout << "c1 的模长小于或等于 c3 的模长" << std::endl;
else
std::cout << "c1 的模长大于 c3 的模长" << std::endl;

if (c1 >= c3)
std::cout << "c1 的模长大于或等于 c3 的模长" << std::endl;
else
std::cout << "c1 的模长小于 c3 的模长" << std::endl;

return 0;
}

输出:

1
2
3
4
c1 的模长不小于 c2 的模长
c1 的模长大于 c2 的模长
c1 的模长小于或等于 c3 的模长
c1 的模长大于或等于 c3 的模长

4. 逻辑运算符

4.1 &&, ||, ! 运算符

作用:实现逻辑操作。需要注意,C++ 中的 && 和 || 运算符无法短路地重载,而且通常不建议重载它们,因为会改变其原有的逻辑语义。通常,建议使用类型转换或其他方法来实现逻辑判断。

示例类:Boolean 类(用于示例)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
#include <iostream>

class Boolean {
private:
bool value;

public:
// 构造函数
Boolean(bool val = false) : value(val) {}

// 重载逻辑非运算符(!)(成员函数)
bool operator!() const {
return !value;
}

// 重载逻辑与运算符(&)(非短路)(成员函数)
Boolean operator&(const Boolean& other) const {
return Boolean(this->value & other.value);
}

// 重载逻辑或运算符(|)(非短路)(成员函数)
Boolean operator|(const Boolean& other) const {
return Boolean(this->value | other.value);
}

// 重载输出运算符
friend std::ostream& operator<<(std::ostream& os, const Boolean& b);
};

// 实现 << 运算符
std::ostream& operator<<(std::ostream& os, const Boolean& b) {
os << (b.value ? "true" : "false");
return os;
}

// 示例
int main() {
Boolean b1(true);
Boolean b2(false);

Boolean b3 = b1 & b2;
Boolean b4 = b1 | b2;
Boolean b5 = !b1;
Boolean b6 = !b2;

std::cout << "b1 & b2 = " << b3 << std::endl;
std::cout << "b1 | b2 = " << b4 << std::endl;
std::cout << "!b1 = " << b5 << std::endl;
std::cout << "!b2 = " << b6 << std::endl;

return 0;
}

输出:

1
2
3
4
b1 & b2 = false
b1 | b2 = true
!b1 = false
!b2 = true

说明:

  • 注意:在重载 && 和 || 运算符时,要明白它们不会具有短路行为。因此,通常不建议重载这两个运算符。
  • 本例中,使用 & 和 | 运算符来模拟逻辑与、或操作。

5. 位运算符

5.1 &, |, ^, ~ 运算符

作用:实现位级操作,如按位与、按位或、按位异或、按位取反。

示例类:Bitmask 类

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
#include <iostream>

class Bitmask {
private:
unsigned int bits;

public:
// 构造函数
Bitmask(unsigned int b = 0) : bits(b) {}

// 重载 & 运算符(成员函数)
Bitmask operator&(const Bitmask& other) const {
return Bitmask(this->bits & other.bits);
}

// 重载 | 运算符(成员函数)
Bitmask operator|(const Bitmask& other) const {
return Bitmask(this->bits | other.bits);
}

// 重载 ^ 运算符(成员函数)
Bitmask operator^(const Bitmask& other) const {
return Bitmask(this->bits ^ other.bits);
}

// 重载 ~ 运算符(成员函数)
Bitmask operator~() const {
return Bitmask(~this->bits);
}

// 重载 << 运算符用于输出
friend std::ostream& operator<<(std::ostream& os, const Bitmask& b);
};

// 实现 << 运算符
std::ostream& operator<<(std::ostream& os, const Bitmask& b) {
os << "0x" << std::hex << b.bits << std::dec;
return os;
}

// 示例
int main() {
Bitmask bm1(0b10101010); // 0xAA
Bitmask bm2(0b11001100); // 0xCC

Bitmask bm3 = bm1 & bm2;
Bitmask bm4 = bm1 | bm2;
Bitmask bm5 = bm1 ^ bm2;
Bitmask bm6 = ~bm1;

std::cout << "bm1 & bm2 = " << bm3 << std::endl;
std::cout << "bm1 | bm2 = " << bm4 << std::endl;
std::cout << "bm1 ^ bm2 = " << bm5 << std::endl;
std::cout << "~bm1 = " << bm6 << std::endl;

return 0;
}

输出:

1
2
3
4
bm1 & bm2 = 0x88
bm1 | bm2 = 0xee
bm1 ^ bm2 = 0x66
~bm1 = 0xffffff55

5.2 <<, >> 位移运算符

作用:实现位移操作,如左移、右移。

示例类:Bitmask 类

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
// ...(与上面类似)

// 重载 << 运算符(左移,成员函数)
Bitmask operator<<(int shift) const {
return Bitmask(this->bits << shift);
}

// 重载 >> 运算符(右移,成员函数)
Bitmask operator>>(int shift) const {
return Bitmask(this->bits >> shift);
}

// 示例
int main() {
Bitmask bm1(0b0001); // 0x1

Bitmask bm2 = bm1 << 3;
Bitmask bm3 = bm1 >> 1;

std::cout << "bm1 << 3 = " << bm2 << std::endl;
std::cout << "bm1 >> 1 = " << bm3 << std::endl;

return 0;
}

输出:

1
2
bm1 << 3 = 0x8
bm1 >> 1 = 0x0

说明:

  • 重载位移运算符时,通常接受一个整型参数,表示位移的位数。

6. 自增自减运算符

6.1 前置 ++ 和 -- 运算符

作用:实现对象的自增和自减操作。

示例类:Counter 类

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
#include <iostream>

class Counter {
private:
int count;

public:
// 构造函数
Counter(int c = 0) : count(c) {}

// 前置 ++ 运算符(成员函数)
Counter& operator++() {
++count;
return *this;
}

// 前置 -- 运算符(成员函数)
Counter& operator--() {
--count;
return *this;
}

// 重载 << 运算符用于输出
friend std::ostream& operator<<(std::ostream& os, const Counter& c);
};

// 实现 << 运算符
std::ostream& operator<<(std::ostream& os, const Counter& c) {
os << c.count;
return os;
}

// 示例
int main() {
Counter c(10);
std::cout << "初始值: " << c << std::endl;
std::cout << "++c = " << ++c << std::endl;
std::cout << "--c = " << --c << std::endl;
return 0;
}

输出:

1
2
3
初始值: 10
++c = 11
--c = 10

6.2 后置 ++ 和 -- 运算符

作用:实现对象的后置自增和自减操作。

示例类:Counter 类

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
// ...(与上面类似)

// 后置 ++ 运算符(成员函数)
Counter operator++(int) {
Counter temp = *this;
++count;
return temp;
}

// 后置 -- 运算符(成员函数)
Counter operator--(int) {
Counter temp = *this;
--count;
return temp;
}

// 示例
int main() {
Counter c(10);
std::cout << "初始值: " << c << std::endl;
std::cout << "c++ = " << c++ << std::endl;
std::cout << "c-- = " << c-- << std::endl;
std::cout << "当前值: " << c << std::endl;
return 0;
}

输出:

1
2
3
4
初始值: 10
c++ = 10
c-- = 11
当前值: 10

说明:

  • 前置运算符:先修改对象,再返回引用。
  • 后置运算符:先保存原值,修改对象,再返回原值。

7. 下标运算符 []

作用:实现对象的下标访问,如数组访问。

示例类:Vector 类

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
#include <iostream>
#include <vector>
#include <stdexcept>

class Vector {
private:
std::vector<double> components;

public:
// 构造函数
Vector(const std::vector<double>& comps) : components(comps) {}

// 重载 [] 运算符(非 const)
double& operator[](size_t index) {
if (index >= components.size()) {
throw std::out_of_range("下标越界!");
}
return components[index];
}

// 重载 [] 运算符(const)
const double& operator[](size_t index) const {
if (index >= components.size()) {
throw std::out_of_range("下标越界!");
}
return components[index];
}

// 重载 << 运算符用于输出
friend std::ostream& operator<<(std::ostream& os, const Vector& v);
};

// 实现 << 运算符
std::ostream& operator<<(std::ostream& os, const Vector& v) {
os << "(";
for (size_t i = 0; i < v.components.size(); ++i) {
os << v.components[i];
if (i != v.components.size() - 1)
os << ", ";
}
os << ")";
return os;
}

// 示例
int main() {
Vector v({1.0, 2.0, 3.0});
std::cout << "初始向量: " << v << std::endl;

// 访问元素
try {
std::cout << "v[1] = " << v[1] << std::endl;
v[1] = 5.0;
std::cout << "修改后向量: " << v << std::endl;
} catch (const std::out_of_range& e) {
std::cerr << "错误: " << e.what() << std::endl;
}

// 访问越界
try {
std::cout << "v[3] = " << v[3] << std::endl; // 越界
} catch (const std::out_of_range& e) {
std::cerr << "错误: " << e.what() << std::endl;
}

return 0;
}

输出:

1
2
3
4
初始向量: (1, 2, 3)
v[1] = 2
修改后向量: (1, 5, 3)
错误: 下标越界!

说明:

  • 提供了 const 和 非 const 两种重载,以支持不同上下文中的访问。
  • 在访问时进行了边界检查,确保安全性。

8. 函数调用运算符 ()

作用:使对象能够像函数一样被调用,常用于函数对象(functors)或仿函数。

示例类:Multiplier 类

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
#include <iostream>

class Multiplier {
private:
double factor;

public:
// 构造函数
Multiplier(double f = 1.0) : factor(f) {}

// 重载 () 运算符(成员函数)
double operator()(double x) const {
return x * factor;
}
};

// 示例
int main() {
Multiplier double_it(2.0);
Multiplier triple_it(3.0);

std::cout << "double_it(5) = " << double_it(5) << std::endl;
std::cout << "triple_it(5) = " << triple_it(5) << std::endl;

return 0;
}

输出:

1
2
double_it(5) = 10
triple_it(5) = 15

说明:

  • 通过重载 () 运算符,Multiplier 对象可以像函数一样接受参数并进行操作。
  • 常用于需要定制函数行为的场景,如排序时的比较函数。

9. 输入输出运算符 <<, >>

作用:实现对象与输入输出流之间的交互。

示例类:Complex(复数类)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
#include <iostream>
#include <stdexcept>

class Complex {
private:
double real;
double imag;

public:
// 构造函数
Complex(double r = 0.0, double i = 0.0) : real(r), imag(i) {}

// 重载 << 运算符(友元函数)
friend std::ostream& operator<<(std::ostream& os, const Complex& c);

// 重载 >> 运算符(友元函数)
friend std::istream& operator>>(std::istream& is, Complex& c);
};

// 实现 << 运算符
std::ostream& operator<<(std::ostream& os, const Complex& c) {
os << "(" << c.real;
if (c.imag >= 0)
os << " + " << c.imag << "i)";
else
os << " - " << -c.imag << "i)";
return os;
}

// 实现 >> 运算符
std::istream& operator>>(std::istream& is, Complex& c) {
// 假设输入格式为:real imag
is >> c.real >> c.imag;
return is;
}

// 示例
int main() {
Complex c1;
std::cout << "请输入复数的实部和虚部,以空格分隔: ";
std::cin >> c1;
std::cout << "您输入的复数是: " << c1 << std::endl;
return 0;
}

示例输入:

1
请输入复数的实部和虚部,以空格分隔: 3.5 -2.1

输出:

1
您输入的复数是: (3.5 - 2.1i)

说明:

  • << 运算符用于输出对象到流中。
  • >> 运算符用于从流中输入对象的数据。
  • 一般将这些运算符重载为友元函数,以便访问类的私有成员。

10. 其他运算符

10.1 成员访问运算符 ->, ->*

说明:

  • 运算符 -> 和 ->* 通常用于代理模式或智能指针的实现,较为复杂。
  • 其重载需要返回一个指针类型,以便进一步访问成员。
  • 通常不建议普通类进行重载,除非有特定需求。

示例类:Proxy 类(代理模式)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
#include <iostream>

class RealObject {
public:
void display() const {
std::cout << "RealObject::display()" << std::endl;
}
};

class Proxy {
private:
RealObject* ptr;

public:
// 构造函数
Proxy(RealObject* p = nullptr) : ptr(p) {}

// 重载 -> 运算符(成员函数)
RealObject* operator->() const {
return ptr;
}
};

// 示例
int main() {
RealObject real;
Proxy proxy(&real);

proxy->display(); // 使用重载的 -> 运算符
return 0;
}

输出:

1
RealObject::display()

说明:

  • Proxy 类通过重载 -> 运算符,将对 Proxy 对象的成员访问转发给其内部的 RealObject 对象。
  • 这是实现代理模式或智能指针的常见方式。

综合案例:复数(Complex)类中的所有运算符重载

为了将上述所有运算符的重载整合在一个类中,以下是一个全面的 Complex 类示例,涵盖了大部分可重载的运算符。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
#include <iostream>
#include <cmath>
#include <stdexcept>

class Complex {
private:
double real;
double imag;

public:
// 构造函数
Complex(double r = 0.0, double i = 0.0) : real(r), imag(i) {}

// 拷贝赋值运算符
Complex& operator=(const Complex& other) {
if (this == &other)
return *this;
this->real = other.real;
this->imag = other.imag;
return *this;
}

// 重载 + 运算符(成员函数)
Complex operator+(const Complex& other) const {
return Complex(this->real + other.real, this->imag + other.imag);
}

// 重载 - 运算符(成员函数)
Complex operator-(const Complex& other) const {
return Complex(this->real - other.real, this->imag - other.imag);
}

// 重载 * 运算符(成员函数)
Complex operator*(const Complex& other) const {
double r = this->real * other.real - this->imag * other.imag;
double i = this->real * other.imag + this->imag * other.real;
return Complex(r, i);
}

// 重载 / 运算符(成员函数)
Complex operator/(const Complex& other) const {
double denominator = other.real * other.real + other.imag * other.imag;
if (denominator == 0) {
throw std::invalid_argument("除数为零!");
}
double r = (this->real * other.real + this->imag * other.imag) / denominator;
double i = (this->imag * other.real - this->real * other.imag) / denominator;
return Complex(r, i);
}

// 重载 += 运算符
Complex& operator+=(const Complex& other) {
this->real += other.real;
this->imag += other.imag;
return *this;
}

// 重载 -= 运算符
Complex& operator-=(const Complex& other) {
this->real -= other.real;
this->imag -= other.imag;
return *this;
}

// 重载 *= 运算符
Complex& operator*=(const Complex& other) {
double r = this->real * other.real - this->imag * other.imag;
double i = this->real * other.imag + this->imag * other.real;
this->real = r;
this->imag = i;
return *this;
}

// 重载 /= 运算符
Complex& operator/=(const Complex& other) {
double denominator = other.real * other.real + other.imag * other.imag;
if (denominator == 0) {
throw std::invalid_argument("除数为零!");
}
double r = (this->real * other.real + this->imag * other.imag) / denominator;
double i = (this->imag * other.real - this->real * other.imag) / denominator;
this->real = r;
this->imag = i;
return *this;
}

// 重载 == 运算符
friend bool operator==(const Complex& lhs, const Complex& rhs) {
return (lhs.real == rhs.real) && (lhs.imag == rhs.imag);
}

// 重载 != 运算符
friend bool operator!=(const Complex& lhs, const Complex& rhs) {
return !(lhs == rhs);
}

// 重载 < 运算符(基于模长)
friend bool operator<(const Complex& lhs, const Complex& rhs) {
double lhs_mod = std::sqrt(lhs.real * lhs.real + lhs.imag * lhs.imag);
double rhs_mod = std::sqrt(rhs.real * rhs.real + rhs.imag * rhs.imag);
return lhs_mod < rhs_mod;
}

// 重载 > 运算符
friend bool operator>(const Complex& lhs, const Complex& rhs) {
return rhs < lhs;
}

// 重载 <= 运算符
friend bool operator<=(const Complex& lhs, const Complex& rhs) {
return !(rhs < lhs);
}

// 重载 >= 运算符
friend bool operator>=(const Complex& lhs, const Complex& rhs) {
return !(lhs < rhs);
}

// 重载 << 运算符
friend std::ostream& operator<<(std::ostream& os, const Complex& c) {
os << "(" << c.real;
if (c.imag >= 0)
os << " + " << c.imag << "i)";
else
os << " - " << -c.imag << "i)";
return os;
}

// 重载 >> 运算符
friend std::istream& operator>>(std::istream& is, Complex& c) {
// 简单输入格式:real imag
is >> c.real >> c.imag;
return is;
}

// 重载 ~ 运算符(取反复数)
Complex operator~() const {
return Complex(this->real, -this->imag);
}

// 重载逻辑非运算符(!)
bool operator!() const {
return (this->real == 0 && this->imag == 0);
}

// 重载下标运算符(如 c[0] 返回 real, c[1] 返回 imag)
double& operator[](size_t index) {
if (index == 0)
return real;
else if (index == 1)
return imag;
else
throw std::out_of_range("下标越界!");
}

const double& operator[](size_t index) const {
if (index == 0)
return real;
else if (index == 1)
return imag;
else
throw std::out_of_range("下标越界!");
}

// 重载前置 ++ 运算符
Complex& operator++() {
++real;
++imag;
return *this;
}

// 重载后置 ++ 运符
Complex operator++(int) {
Complex temp = *this;
++real;
++imag;
return temp;
}

// 重载前置 -- 运算符
Complex& operator--() {
--real;
--imag;
return *this;
}

// 重载后置 -- 运算符
Complex operator--(int) {
Complex temp = *this;
--real;
--imag;
return temp;
}

// 重载函数调用运算符
double operator()(const std::string& part) const {
if (part == "real")
return real;
else if (part == "imag")
return imag;
else
throw std::invalid_argument("参数错误!");
}
};

// 示例
int main() {
Complex c1(3.0, 4.0);
Complex c2(1.5, -2.5);
Complex c3;

// 赋值运算
c3 = c1;
std::cout << "c3 = " << c3 << std::endl;

// 加法
Complex c4 = c1 + c2;
std::cout << "c1 + c2 = " << c4 << std::endl;

// 减法
Complex c5 = c1 - c2;
std::cout << "c1 - c2 = " << c5 << std::endl;

// 乘法
Complex c6 = c1 * c2;
std::cout << "c1 * c2 = " << c6 << std::endl;

// 除法
try {
Complex c7 = c1 / c2;
std::cout << "c1 / c2 = " << c7 << std::endl;
} catch (const std::invalid_argument& e) {
std::cerr << "错误: " << e.what() << std::endl;
}

// 比较运算
if (c1 > c2)
std::cout << "c1 的模长大于 c2" << std::endl;
else
std::cout << "c1 的模长不大于 c2" << std::endl;

// 逻辑非运算
Complex c_zero;
if (!c1)
std::cout << "c1 是零复数" << std::endl;
else
std::cout << "c1 不是零复数" << std::endl;

if (!c_zero)
std::cout << "c_zero 是零复数" << std::endl;
else
std::cout << "c_zero 不是零复数" << std::endl;

// 取反运算
Complex c_neg = ~c1;
std::cout << "~c1 = " << c_neg << std::endl;

// 下标运算
try {
std::cout << "c1[0] (real) = " << c1[0] << std::endl;
std::cout << "c1[1] (imag) = " << c1[1] << std::endl;
// std::cout << "c1[2] = " << c1[2] << std::endl; // 会抛出异常
} catch (const std::out_of_range& e) {
std::cerr << "错误: " << e.what() << std::endl;
}

// 自增自减运算符
std::cout << "c1 = " << c1 << std::endl;
std::cout << "++c1 = " << ++c1 << std::endl;
std::cout << "c1++ = " << c1++ << std::endl;
std::cout << "c1 = " << c1 << std::endl;
std::cout << "--c1 = " << --c1 << std::endl;
std::cout << "c1-- = " << c1-- << std::endl;
std::cout << "c1 = " << c1 << std::endl;

// 函数调用运算符
std::cout << "c1 的实部: " << c1("real") << std::endl;
std::cout << "c1 的虚部: " << c1("imag") << std::endl;

// 输入运算
std::cout << "请输入一个复数的实部和虚部,以空格分隔: ";
Complex c_input;
std::cin >> c_input;
std::cout << "您输入的复数是: " << c_input << std::endl;

return 0;
}

示例运行:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
c3 = (3 + 4i)
c1 + c2 = (4.5 + 1.5i)
c1 - c2 = (1.5 + 6.5i)
c1 * c2 = (13.5 + 1i)
c1 / c2 = (-0.823529 + 1.64706i)
c1 的模长大于 c2
c1 不是零复数
c_zero 是零复数
~c1 = (3 - 4i)
c1[0] (real) = 3
c1[1] (imag) = 4
c1 = (3 + 4i)
++c1 = (4 + 5i)
c1++ = (4 + 5i)
c1 = (5 + 6i)
--c1 = (4 + 5i)
c1-- = (4 + 5i)
c1 = (3 + 4i)
c1 的实部: 3
c1 的虚部: 4
请输入一个复数的实部和虚部,以空格分隔: 2.5 -3.5
您输入的复数是: (2.5 - 3.5i)

说明:

  • 该类集成了大部分可重载的运算符,包括算术、赋值、比较、逻辑、位运算、自增自减、下标、函数调用以及输入输出运算符。
  • 某些运算符(如 &&, ||, ->*)未在此示例中体现,因为它们的重载较为复杂且不常见。
  • 在实际开发中,应根据需求选择性地重载运算符,避免过度设计。

11. 其他可重载运算符

11.1 逗号运算符 ,

作用:实现对象在逗号表达式中的行为。

示例类:Logger 类(用于示例)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
#include <iostream>
#include <string>

class Logger {
private:
std::string message;

public:
// 构造函数
Logger(const std::string& msg = "") : message(msg) {}

// 重载逗号运算符(成员函数)
Logger operator,(const Logger& other) const {
// 简单示例:连接日志消息
return Logger(this->message + ", " + other.message);
}

// 重载 << 运算符用于输出
friend std::ostream& operator<<(std::ostream& os, const Logger& l);
};

// 实现 << 运算符
std::ostream& operator<<(std::ostream& os, const Logger& l) {
os << l.message;
return os;
}

// 示例
int main() {
Logger log1("启动");
Logger log2("加载配置");
Logger log3("初始化");
Logger combined = (log1, log2, log3);
std::cout << "组合日志: " << combined << std::endl;
return 0;
}

输出:

1
组合日志: 启动, 加载配置, 初始化

说明:

  • 重载 , 运算符可以自定义逗号表达式的行为,但在实际应用中不常见,应谨慎使用。
  • 多个逗号运算符的重载会按从左至右的顺序依次调用。

运算符重载注意事项

  1. 语义一致性:重载运算符后,其行为应与运算符的传统意义保持一致。例如,+ 应表示加法,避免引起混淆。
  2. 效率:尽量避免不必要的对象拷贝,可以通过返回引用或使用移动语义提升效率。
  3. 异常安全:在实现运算符重载时,考虑并处理可能的异常情况,确保程序的健壮性。
  4. 封装性:保持类的封装性,避免过度暴露内部细节。仅在必要时使用友元函数。
  5. 返回类型:根据运算符的用途选择合适的返回类型。例如,算术运算符通常返回新对象,赋值运算符返回引用等。
  6. 避免复杂的逻辑:运算符重载应简洁明了,不应包含过于复杂的逻辑,避免使代码难以理解和维护。
  7. 可读性:使用适当的注释和文档说明运算符重载的行为,增强代码的可读性。

小结

运算符重载是C++中强大的特性,允许开发者为自定义类定义或重新定义运算符的行为,使对象的操作更加直观和符合逻辑。在设计和实现运算符重载时,应遵循语义一致性、效率和封装性等原则,避免滥用。通过本教案中的详细案例,学习者可以全面理解运算符重载的应用,并在实际编程中灵活运用。

unorderedmap以及手写无序map

Posted on 2025-01-14 | In 零基础C++

unordermap用法

unordered_map 是 C++ 标准库中的关联容器,提供了基于哈希表的键值对存储结构。与 map (基于红黑树实现)不同,unordered_map 提供的是平均常数时间复杂度的查找、插入和删除操作,但不保证元素的顺序。

以下是 unordered_map 的详细用法说明:

1. 头文件

要使用 unordered_map,需要包含头文件:

1
#include <unordered_map>

2. 基本定义

unordered_map 的基本模板定义如下:

1
std::unordered_map<KeyType, ValueType, Hash = std::hash<KeyType>, KeyEqual = std::equal_to<KeyType>, Allocator = std::allocator<std::pair<const KeyType, ValueType>>>

常用模板参数:

  • KeyType:键的类型,需要支持哈希运算和相等比较。
  • ValueType:值的类型。
  • Hash:哈希函数,默认为 std::hash<KeyType>。
  • KeyEqual:键相等的比较函数,默认为 std::equal_to<KeyType>。
  • Allocator:内存分配器,默认为 std::allocator。

3. 常用操作

3.1 创建和初始化

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
#include <iostream>
#include <unordered_map>
#include <string>

int main() {
// 创建一个空的 unordered_map,键为 string,值为 int
std::unordered_map<std::string, int> umap;

// 使用初始化列表初始化
std::unordered_map<std::string, int> umap_init = {
{"apple", 3},
{"banana", 5},
{"orange", 2}
};

// 使用其他容器范围初始化
std::map<std::string, int> ordered_map = {{"carrot", 4}, {"lettuce", 1}};
std::unordered_map<std::string, int> umap_from_map(ordered_map.begin(), ordered_map.end());

return 0;
}

3.2 插入元素

1
2
3
4
5
6
7
8
// 方法1:使用下标操作符
umap["grape"] = 7;

// 方法2:使用 insert
umap.insert({"melon", 6});

// 方法3:使用 emplace,直接在容器内部构造元素
umap.emplace("kiwi", 4);

3.3 访问元素

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
// 使用下标操作符访问或插入
int apple_count = umap["apple"]; // 如果 "apple" 不存在,会插入一个默认值

// 使用 at() 方法访问,不存在时会抛出异常
try {
int banana_count = umap.at("banana");
} catch (const std::out_of_range& e) {
std::cerr << "Key not found." << std::endl;
}

// 使用 find() 方法查找
auto it = umap.find("orange");
if (it != umap.end()) {
std::cout << "Orange count: " << it->second << std::endl;
} else {
std::cout << "Orange not found." << std::endl;
}

3.4 删除元素

1
2
3
4
5
6
7
8
9
10
11
// 根据键删除
umap.erase("grape");

// 根据迭代器删除
auto it = umap.find("banana");
if (it != umap.end()) {
umap.erase(it);
}

// 清空整个容器
umap.clear();

3.5 遍历元素

1
2
3
4
5
6
7
8
for (const auto& pair : umap) {
std::cout << pair.first << ": " << pair.second << std::endl;
}

// 使用迭代器
for (auto it = umap.begin(); it != umap.end(); ++it) {
std::cout << it->first << ": " << it->second << std::endl;
}

3.6 其他常用方法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
// 获取大小
size_t size = umap.size();

// 检查是否为空
bool is_empty = umap.empty();

// 获取桶的数量(用于哈希表内部结构)
size_t bucket_count = umap.bucket_count();

// 重新哈希,调整桶的数量
umap.rehash(20);

// 从一个容器中移交元素到另一个容器
std::unordered_map<std::string, int> umap2 = std::move(umap);

4. 性能优化

4.1 预分配桶数

如果预先知道元素的大致数量,可以通过 reserve 预分配内存,以减少哈希表的重哈希开销:

1
umap.reserve(100); // 预分配足够容纳100个元素的桶

4.2 自定义哈希函数

如果键的类型是自定义类型或需要特殊的哈希策略,可以自定义哈希函数。例如,自定义结构体作为键:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
struct Point {
int x;
int y;

bool operator==(const Point& other) const {
return x == other.x && y == other.y;
}
};

// 自定义哈希函数
struct PointHash {
std::size_t operator()(const Point& p) const {
return std::hash<int>()(p.x) ^ (std::hash<int>()(p.y) << 1);
}
};

// 定义 unordered_map 使用自定义哈希函数
std::unordered_map<Point, std::string, PointHash> point_map;
point_map[{1, 2}] = "A";
point_map[{3, 4}] = "B";

4.3 自定义键相等比较

如果需要自定义键的比较逻辑,可以提供自定义的 KeyEqual 函数对象:

1
2
3
4
5
6
7
8
struct PointEqual {
bool operator()(const Point& a, const Point& b) const {
return (a.x == b.x) && (a.y == b.y);
}
};

// 定义 unordered_map 使用自定义哈希和比较函数
std::unordered_map<Point, std::string, PointHash, PointEqual> point_map;

5. 与 map 的比较

  • 底层实现:unordered_map 基于哈希表,实现的操作平均时间复杂度为常数级别;map 基于红黑树,实现的查找、插入、删除操作时间复杂度为对数级别。
  • 元素顺序:unordered_map 不保证元素的顺序;map 按键的顺序(通常是升序)存储元素。
  • 适用场景:当需要快速查找、插入和删除,且不关心元素顺序时,选择 unordered_map;当需要有序存储或按顺序遍历时,选择 map。

6. 完整示例

以下是一个使用 unordered_map 的完整示例:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
#include <iostream>
#include <unordered_map>
#include <string>

// 自定义类型
struct Person {
std::string name;
int age;

bool operator==(const Person& other) const {
return name == other.name && age == other.age;
}
};

// 自定义哈希函数
struct PersonHash {
std::size_t operator()(const Person& p) const {
return std::hash<std::string>()(p.name) ^ (std::hash<int>()(p.age) << 1);
}
};

int main() {
// 创建一个 unordered_map,键为 string,值为 int
std::unordered_map<std::string, int> fruit_count;
fruit_count["apple"] = 5;
fruit_count["banana"] = 3;
fruit_count.emplace("orange", 2);

// 访问元素
std::cout << "Apple count: " << fruit_count["apple"] << std::endl;

// 遍历
for (const auto& pair : fruit_count) {
std::cout << pair.first << ": " << pair.second << std::endl;
}

// 使用自定义类型作为键
std::unordered_map<Person, std::string, PersonHash> person_map;
Person p1{"Alice", 30};
Person p2{"Bob", 25};
person_map[p1] = "Engineer";
person_map.emplace(p2, "Designer");

// 访问自定义类型键的值
Person p3{"Alice", 30};
std::cout << "Alice's job: " << person_map[p3] << std::endl;

return 0;
}

输出示例:

1
2
3
4
5
Apple count: 5
banana: 3
orange: 2
apple: 5
Alice's job: Engineer

手写unordermap

1. 哈希表的基本原理

哈希表是一种基于键值对的数据结构,通过哈希函数(Hash Function)将键映射到表中的一个索引位置,以实现快速的数据访问。哈希表的关键特性包括:

  • 哈希函数:将键映射到表中一个特定的桶(Bucket)或槽(Slot)。
  • 冲突解决:当不同的键通过哈希函数映射到同一个桶时,需要一种机制来处理这些冲突。常见的方法有链地址法(Separate Chaining)和开放地址法(Open Addressing)。
  • 负载因子(Load Factor):表示表中已存储元素的数量与表大小之间的比率。高负载因子可能导致更多的冲突,需要通过扩容来维持性能。

在本实现中,我们将采用链地址法来处理哈希冲突,即每个桶存储一个链表(或其他动态数据结构)来存储具有相同哈希值的元素。

2. 数据结构设计

HashNode 结构

HashNode 用于存储键值对及相关的指针,以构建链表。每个 HashNode 包含键、值和指向下一个节点的指针。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
#include <vector>
#include <list>
#include <utility> // For std::pair
#include <functional> // For std::hash
#include <iterator> // For iterator_traits
#include <stdexcept> // For exceptions

// HashNode 结构定义
template <typename Key, typename T>
struct HashNode {
std::pair<const Key, T> data;
HashNode* next;

HashNode(const Key& key, const T& value)
: data(std::make_pair(key, value)), next(nullptr) {}
};

MyHashMap 类定义

MyHashMap 是我们自定义的哈希表实现,支持基本的 Map 操作和迭代器功能。它使用一个向量(std::vector)来存储桶,每个桶是一个链表,用于处理冲突。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
template <typename Key, typename T, typename Hash = std::hash<Key>>
class MyHashMap {
public:
// 迭代器类前向声明
class Iterator;

// 类型定义
using key_type = Key;
using mapped_type = T;
using value_type = std::pair<const Key, T>;
using size_type = size_t;

// 构造函数及析构函数
MyHashMap(size_type initial_capacity = 16, double max_load_factor = 0.75);
~MyHashMap();

// 禁止拷贝构造和赋值
MyHashMap(const MyHashMap&) = delete;
MyHashMap& operator=(const MyHashMap&) = delete;

// 基本操作
void insert(const Key& key, const T& value);
T* find(const Key& key);
const T* find(const Key& key) const;
bool erase(const Key& key);
size_type size() const;
bool empty() const;
void clear();

// 迭代器操作
Iterator begin();
Iterator end();

// 迭代器类
class Iterator {
public:
// 迭代器别名
using iterator_category = std::forward_iterator_tag;
using value_type = std::pair<const Key, T>;
using difference_type = std::ptrdiff_t;
using pointer = value_type*;
using reference = value_type&;

// 构造函数
Iterator(MyHashMap* map, size_type bucket_index, HashNode<Key, T>* node);

// 解引用操作符
reference operator*() const;
pointer operator->() const;

// 递增操作符
Iterator& operator++();
Iterator operator++(int);

// 比较操作符
bool operator==(const Iterator& other) const;
bool operator!=(const Iterator& other) const;

private:
MyHashMap* map_;
size_type bucket_index_;
HashNode<Key, T>* current_node_;

// 移动到下一个有效节点
void advance();
};

private:
std::vector<HashNode<Key, T>*> buckets_;
size_type bucket_count_;
size_type element_count_;
double max_load_factor_;
Hash hash_func_;

// 辅助函数
void rehash();
};

3. 基本操作实现

构造函数及析构函数

初始化哈希表,设置初始容量和负载因子。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
// 构造函数
template <typename Key, typename T, typename Hash>
MyHashMap<Key, T, Hash>::MyHashMap(size_type initial_capacity, double max_load_factor)
: bucket_count_(initial_capacity),
element_count_(0),
max_load_factor_(max_load_factor),
hash_func_(Hash()) {
buckets_.resize(bucket_count_, nullptr);
}

// 析构函数
template <typename Key, typename T, typename Hash>
MyHashMap<Key, T, Hash>::~MyHashMap() {
clear();
}

插入(Insert)

向哈希表中插入键值对。如果键已存在,则更新其值。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
template <typename Key, typename T, typename Hash>
void MyHashMap<Key, T, Hash>::insert(const Key& key, const T& value) {
size_type hash_value = hash_func_(key);
size_type index = hash_value % bucket_count_;

HashNode<Key, T>* node = buckets_[index];
while (node != nullptr) {
if (node->data.first == key) {
node->data.second = value; // 更新值
return;
}
node = node->next;
}

// 键不存在,插入新节点到链表头部
HashNode<Key, T>* new_node = new HashNode<Key, T>(key, value);
new_node->next = buckets_[index];
buckets_[index] = new_node;
++element_count_;

// 检查负载因子,可能需要扩容
double load_factor = static_cast<double>(element_count_) / bucket_count_;
if (load_factor > max_load_factor_) {
rehash();
}
}

查找(Find)

根据键查找对应的值,返回指向值的指针。如果未找到,则返回 nullptr。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
template <typename Key, typename T, typename Hash>
T* MyHashMap<Key, T, Hash>::find(const Key& key) {
size_type hash_value = hash_func_(key);
size_type index = hash_value % bucket_count_;

HashNode<Key, T>* node = buckets_[index];
while (node != nullptr) {
if (node->data.first == key) {
return &(node->data.second);
}
node = node->next;
}
return nullptr;
}

template <typename Key, typename T, typename Hash>
const T* MyHashMap<Key, T, Hash>::find(const Key& key) const {
size_type hash_value = hash_func_(key);
size_type index = hash_value % bucket_count_;

HashNode<Key, T>* node = buckets_[index];
while (node != nullptr) {
if (node->data.first == key) {
return &(node->data.second);
}
node = node->next;
}
return nullptr;
}

删除(Erase)

根据键删除对应的键值对,返回删除是否成功。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
template <typename Key, typename T, typename Hash>
bool MyHashMap<Key, T, Hash>::erase(const Key& key) {
size_type hash_value = hash_func_(key);
size_type index = hash_value % bucket_count_;

HashNode<Key, T>* node = buckets_[index];
HashNode<Key, T>* prev = nullptr;

while (node != nullptr) {
if (node->data.first == key) {
if (prev == nullptr) {
buckets_[index] = node->next;
} else {
prev->next = node->next;
}
delete node;
--element_count_;
return true;
}
prev = node;
node = node->next;
}
return false; // 未找到键
}

清空(Clear)

删除哈希表中的所有元素,释放内存。

1
2
3
4
5
6
7
8
9
10
11
12
13
template <typename Key, typename T, typename Hash>
void MyHashMap<Key, T, Hash>::clear() {
for (size_type i = 0; i < bucket_count_; ++i) {
HashNode<Key, T>* node = buckets_[i];
while (node != nullptr) {
HashNode<Key, T>* temp = node;
node = node->next;
delete temp;
}
buckets_[i] = nullptr;
}
element_count_ = 0;
}

动态扩容(Rehashing)

当负载因子超过阈值时,扩展哈希表容量并重新分配所有元素。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
template <typename Key, typename T, typename Hash>
void MyHashMap<Key, T, Hash>::rehash() {
size_type new_bucket_count = bucket_count_ * 2;
std::vector<HashNode<Key, T>*> new_buckets(new_bucket_count, nullptr);

// 重新分配所有元素
for (size_type i = 0; i < bucket_count_; ++i) {
HashNode<Key, T>* node = buckets_[i];
while (node != nullptr) {
HashNode<Key, T>* next_node = node->next;
size_type new_index = hash_func_(node->data.first) % new_bucket_count;

// 插入到新桶的头部
node->next = new_buckets[new_index];
new_buckets[new_index] = node;

node = next_node;
}
}

// 替换旧桶
buckets_ = std::move(new_buckets);
bucket_count_ = new_bucket_count;
}

获取大小和状态

1
2
3
4
5
6
7
8
9
template <typename Key, typename T, typename Hash>
typename MyHashMap<Key, T, Hash>::size_type MyHashMap<Key, T, Hash>::size() const {
return element_count_;
}

template <typename Key, typename T, typename Hash>
bool MyHashMap<Key, T, Hash>::empty() const {
return element_count_ == 0;
}

4. 迭代器的实现

为了支持迭代器操作,使 MyHashMap 能够像标准容器一样被遍历,我们需要实现一个内部的 Iterator 类。

Iterator 类定义

Iterator 类需要跟踪当前桶的索引和当前节点指针。它还需要能够找到下一个有效的节点。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
// Iterator 构造函数
template <typename Key, typename T, typename Hash>
MyHashMap<Key, T, Hash>::Iterator::Iterator(MyHashMap* map, size_type bucket_index, HashNode<Key, T>* node)
: map_(map), bucket_index_(bucket_index), current_node_(node) {}

// 解引用操作符
template <typename Key, typename T, typename Hash>
typename MyHashMap<Key, T, Hash>::Iterator::reference
MyHashMap<Key, T, Hash>::Iterator::operator*() const {
if (current_node_ == nullptr) {
throw std::out_of_range("Iterator out of range");
}
return current_node_->data;
}

// 成员访问操作符
template <typename Key, typename T, typename Hash>
typename MyHashMap<Key, T, Hash>::Iterator::pointer
MyHashMap<Key, T, Hash>::Iterator::operator->() const {
if (current_node_ == nullptr) {
throw std::out_of_range("Iterator out of range");
}
return &(current_node_->data);
}

// 前置递增操作符
template <typename Key, typename T, typename Hash>
typename MyHashMap<Key, T, Hash>::Iterator&
MyHashMap<Key, T, Hash>::Iterator::operator++() {
advance();
return *this;
}

// 后置递增操作符
template <typename Key, typename T, typename Hash>
typename MyHashMap<Key, T, Hash>::Iterator
MyHashMap<Key, T, Hash>::Iterator::operator++(int) {
Iterator temp = *this;
advance();
return temp;
}

// 比较操作符==
template <typename Key, typename T, typename Hash>
bool MyHashMap<Key, T, Hash>::Iterator::operator==(const Iterator& other) const {
return map_ == other.map_ &&
bucket_index_ == other.bucket_index_ &&
current_node_ == other.current_node_;
}

// 比较操作符!=
template <typename Key, typename T, typename Hash>
bool MyHashMap<Key, T, Hash>::Iterator::operator!=(const Iterator& other) const {
return !(*this == other);
}

// advance 函数:移动到下一个有效节点
template <typename Key, typename T, typename Hash>
void MyHashMap<Key, T, Hash>::Iterator::advance() {
if (current_node_ != nullptr) {
current_node_ = current_node_->next;
}

while (current_node_ == nullptr && bucket_index_ + 1 < map_->bucket_count_) {
++bucket_index_;
current_node_ = map_->buckets_[bucket_index_];
}
}

迭代器操作

在 MyHashMap 类中实现 begin() 和 end() 函数来返回迭代器。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
// begin() 函数
template <typename Key, typename T, typename Hash>
typename MyHashMap<Key, T, Hash>::Iterator
MyHashMap<Key, T, Hash>::begin() {
for (size_type i = 0; i < bucket_count_; ++i) {
if (buckets_[i] != nullptr) {
return Iterator(this, i, buckets_[i]);
}
}
return end();
}

// end() 函数
template <typename Key, typename T, typename Hash>
typename MyHashMap<Key, T, Hash>::Iterator
MyHashMap<Key, T, Hash>::end() {
return Iterator(this, bucket_count_, nullptr);
}

5. 完整代码示例

以下是完整的 MyHashMap 实现,包括所有上述内容:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
#include <vector>
#include <list>
#include <utility> // For std::pair
#include <functional> // For std::hash
#include <iterator> // For iterator_traits
#include <stdexcept> // For exceptions
#include <iostream>

// HashNode 结构定义
template <typename Key, typename T>
struct HashNode {
std::pair<const Key, T> data;
HashNode* next;

HashNode(const Key& key, const T& value)
: data(std::make_pair(key, value)), next(nullptr) {}
};

// MyHashMap 类定义
template <typename Key, typename T, typename Hash = std::hash<Key>>
class MyHashMap {
public:
// 迭代器类前向声明
class Iterator;

// 类型定义
using key_type = Key;
using mapped_type = T;
using value_type = std::pair<const Key, T>;
using size_type = size_t;

// 构造函数及析构函数
MyHashMap(size_type initial_capacity = 16, double max_load_factor = 0.75);
~MyHashMap();

// 禁止拷贝构造和赋值
MyHashMap(const MyHashMap&) = delete;
MyHashMap& operator=(const MyHashMap&) = delete;

// 基本操作
void insert(const Key& key, const T& value);
T* find(const Key& key);
const T* find(const Key& key) const;
bool erase(const Key& key);
size_type size() const;
bool empty() const;
void clear();

// 迭代器操作
Iterator begin();
Iterator end();

// 迭代器类
class Iterator {
public:
// 迭代器别名
using iterator_category = std::forward_iterator_tag;
using value_type = std::pair<const Key, T>;
using difference_type = std::ptrdiff_t;
using pointer = value_type*;
using reference = value_type&;

// 构造函数
Iterator(MyHashMap* map, size_type bucket_index, HashNode<Key, T>* node);

// 解引用操作符
reference operator*() const;
pointer operator->() const;

// 递增操作符
Iterator& operator++();
Iterator operator++(int);

// 比较操作符
bool operator==(const Iterator& other) const;
bool operator!=(const Iterator& other) const;

private:
MyHashMap* map_;
size_type bucket_index_;
HashNode<Key, T>* current_node_;

// 移动到下一个有效节点
void advance();
};

private:
std::vector<HashNode<Key, T>*> buckets_;
size_type bucket_count_;
size_type element_count_;
double max_load_factor_;
Hash hash_func_;

// 辅助函数
void rehash();
};

// 构造函数
template <typename Key, typename T, typename Hash>
MyHashMap<Key, T, Hash>::MyHashMap(size_type initial_capacity, double max_load_factor)
: bucket_count_(initial_capacity),
element_count_(0),
max_load_factor_(max_load_factor),
hash_func_(Hash()) {
buckets_.resize(bucket_count_, nullptr);
}

// 析构函数
template <typename Key, typename T, typename Hash>
MyHashMap<Key, T, Hash>::~MyHashMap() {
clear();
}

// 插入函数
template <typename Key, typename T, typename Hash>
void MyHashMap<Key, T, Hash>::insert(const Key& key, const T& value) {
size_type hash_value = hash_func_(key);
size_type index = hash_value % bucket_count_;

HashNode<Key, T>* node = buckets_[index];
while (node != nullptr) {
if (node->data.first == key) {
node->data.second = value; // 更新值
return;
}
node = node->next;
}

// 键不存在,插入新节点到链表头部
HashNode<Key, T>* new_node = new HashNode<Key, T>(key, value);
new_node->next = buckets_[index];
buckets_[index] = new_node;
++element_count_;

// 检查负载因子,可能需要扩容
double load_factor = static_cast<double>(element_count_) / bucket_count_;
if (load_factor > max_load_factor_) {
rehash();
}
}

// 查找函数(非常量版本)
template <typename Key, typename T, typename Hash>
T* MyHashMap<Key, T, Hash>::find(const Key& key) {
size_type hash_value = hash_func_(key);
size_type index = hash_value % bucket_count_;

HashNode<Key, T>* node = buckets_[index];
while (node != nullptr) {
if (node->data.first == key) {
return &(node->data.second);
}
node = node->next;
}
return nullptr;
}

// 查找函数(常量版本)
template <typename Key, typename T, typename Hash>
const T* MyHashMap<Key, T, Hash>::find(const Key& key) const {
size_type hash_value = hash_func_(key);
size_type index = hash_value % bucket_count_;

HashNode<Key, T>* node = buckets_[index];
while (node != nullptr) {
if (node->data.first == key) {
return &(node->data.second);
}
node = node->next;
}
return nullptr;
}

// 删除函数
template <typename Key, typename T, typename Hash>
bool MyHashMap<Key, T, Hash>::erase(const Key& key) {
size_type hash_value = hash_func_(key);
size_type index = hash_value % bucket_count_;

HashNode<Key, T>* node = buckets_[index];
HashNode<Key, T>* prev = nullptr;

while (node != nullptr) {
if (node->data.first == key) {
if (prev == nullptr) {
buckets_[index] = node->next;
} else {
prev->next = node->next;
}
delete node;
--element_count_;
return true;
}
prev = node;
node = node->next;
}
return false; // 未找到键
}

// 清空函数
template <typename Key, typename T, typename Hash>
void MyHashMap<Key, T, Hash>::clear() {
for (size_type i = 0; i < bucket_count_; ++i) {
HashNode<Key, T>* node = buckets_[i];
while (node != nullptr) {
HashNode<Key, T>* temp = node;
node = node->next;
delete temp;
}
buckets_[i] = nullptr;
}
element_count_ = 0;
}

// 动态扩容函数
template <typename Key, typename T, typename Hash>
void MyHashMap<Key, T, Hash>::rehash() {
size_type new_bucket_count = bucket_count_ * 2;
std::vector<HashNode<Key, T>*> new_buckets(new_bucket_count, nullptr);

// 重新分配所有元素
for (size_type i = 0; i < bucket_count_; ++i) {
HashNode<Key, T>* node = buckets_[i];
while (node != nullptr) {
HashNode<Key, T>* next_node = node->next;
size_type new_index = hash_func_(node->data.first) % new_bucket_count;

// 插入到新桶的头部
node->next = new_buckets[new_index];
new_buckets[new_index] = node;

node = next_node;
}
}

// 替换旧桶
buckets_ = std::move(new_buckets);
bucket_count_ = new_bucket_count;
}

// begin() 函数
template <typename Key, typename T, typename Hash>
typename MyHashMap<Key, T, Hash>::Iterator
MyHashMap<Key, T, Hash>::begin() {
for (size_type i = 0; i < bucket_count_; ++i) {
if (buckets_[i] != nullptr) {
return Iterator(this, i, buckets_[i]);
}
}
return end();
}

// end() 函数
template <typename Key, typename T, typename Hash>
typename MyHashMap<Key, T, Hash>::Iterator
MyHashMap<Key, T, Hash>::end() {
return Iterator(this, bucket_count_, nullptr);
}

// Iterator 构造函数
template <typename Key, typename T, typename Hash>
MyHashMap<Key, T, Hash>::Iterator::Iterator(MyHashMap* map, size_type bucket_index, HashNode<Key, T>* node)
: map_(map), bucket_index_(bucket_index), current_node_(node) {}

// 解引用操作符
template <typename Key, typename T, typename Hash>
typename MyHashMap<Key, T, Hash>::Iterator::reference
MyHashMap<Key, T, Hash>::Iterator::operator*() const {
if (current_node_ == nullptr) {
throw std::out_of_range("Iterator out of range");
}
return current_node_->data;
}

// 成员访问操作符
template <typename Key, typename T, typename Hash>
typename MyHashMap<Key, T, Hash>::Iterator::pointer
MyHashMap<Key, T, Hash>::Iterator::operator->() const {
if (current_node_ == nullptr) {
throw std::out_of_range("Iterator out of range");
}
return &(current_node_->data);
}

// 前置递增操作符
template <typename Key, typename T, typename Hash>
typename MyHashMap<Key, T, Hash>::Iterator&
MyHashMap<Key, T, Hash>::Iterator::operator++() {
advance();
return *this;
}

// 后置递增操作符
template <typename Key, typename T, typename Hash>
typename MyHashMap<Key, T, Hash>::Iterator
MyHashMap<Key, T, Hash>::Iterator::operator++(int) {
Iterator temp = *this;
advance();
return temp;
}

// 比较操作符==
template <typename Key, typename T, typename Hash>
bool MyHashMap<Key, T, Hash>::Iterator::operator==(const Iterator& other) const {
return map_ == other.map_ &&
bucket_index_ == other.bucket_index_ &&
current_node_ == other.current_node_;
}

// 比较操作符!=
template <typename Key, typename T, typename Hash>
bool MyHashMap<Key, T, Hash>::Iterator::operator!=(const Iterator& other) const {
return !(*this == other);
}

// advance 函数:移动到下一个有效节点
template <typename Key, typename T, typename Hash>
void MyHashMap<Key, T, Hash>::Iterator::advance() {
if (current_node_ != nullptr) {
current_node_ = current_node_->next;
}

while (current_node_ == nullptr && bucket_index_ + 1 < map_->bucket_count_) {
++bucket_index_;
current_node_ = map_->buckets_[bucket_index_];
}
}

6. 使用示例

以下是一个使用 MyHashMap 的示例,展示如何插入、查找、删除以及使用迭代器遍历元素。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
int main() {
MyHashMap<std::string, int> myMap;

// 插入元素
myMap.insert("apple", 3);
myMap.insert("banana", 5);
myMap.insert("orange", 2);
myMap.insert("grape", 7);
myMap.insert("cherry", 4);

// 使用迭代器遍历元素
std::cout << "Map contents:\n";
for(auto it = myMap.begin(); it != myMap.end(); ++it) {
std::cout << it->first << " => " << it->second << "\n";
}

// 查找元素
std::string keyToFind = "banana";
int* value = myMap.find(keyToFind);
if(value != nullptr) {
std::cout << "\nFound " << keyToFind << " with value: " << *value << "\n";
} else {
std::cout << "\n" << keyToFind << " not found.\n";
}

// 删除元素
myMap.erase("apple");
myMap.erase("cherry");

// 再次遍历
std::cout << "\nAfter erasing apple and cherry:\n";
for(auto it = myMap.begin(); it != myMap.end(); ++it) {
std::cout << it->first << " => " << it->second << "\n";
}

return 0;
}

预期输出

1
2
3
4
5
6
7
8
9
10
11
12
13
Map contents:
cherry => 4
banana => 5
apple => 3
grape => 7
orange => 2

Found banana with value: 5

After erasing apple and cherry:
banana => 5
grape => 7
orange => 2

注意:由于哈希表的桶顺序依赖于哈希函数的实现,输出顺序可能与预期有所不同。

Hash() 在上述 MyHashMap 实现中是一个哈希函数对象。让我们详细解释一下它的含义以及它在代码中的作用。

Hash 是什么?

在 MyHashMap 的模板定义中,Hash 是一个模板参数,用于指定键类型 Key 的哈希函数。它有一个默认值 std::hash<Key>,这意味着如果用户在实例化 MyHashMap 时没有提供自定义的哈希函数,std::hash<Key> 将被使用。

1
2
3
4
template <typename Key, typename T, typename Hash = std::hash<Key>>
class MyHashMap {
// ...
};

默认情况下:std::hash<Key>

std::hash 是 C++ 标准库(STL)中提供的一个模板结构,用于为各种内置类型(如 int, std::string 等)生成哈希值。std::hash<Key> 会根据 Key 的类型自动选择合适的哈希函数实现。

例如:

  • 对于 int 类型,std::hash<int> 会生成一个简单的哈希值。
  • 对于 std::string 类型,std::hash<std::string> 会基于字符串内容生成哈希值。
1
2
3
4
5
std::hash<int> intHasher;
size_t hashValue = intHasher(42); // 生成整数 42 的哈希值

std::hash<std::string> stringHasher;
size_t stringHash = stringHasher("hello"); // 生成字符串 "hello" 的哈希值

自定义哈希函数

除了使用 std::hash,用户还可以自定义哈希函数,以适应特定的需求或优化性能。例如,假设你有一个自定义的键类型 Point:

1
2
3
4
5
6
7
8
struct Point {
int x;
int y;

bool operator==(const Point& other) const {
return x == other.x && y == other.y;
}
};

你可以定义一个自定义的哈希函数 PointHasher:

1
2
3
4
5
6
struct PointHasher {
size_t operator()(const Point& p) const {
// 简单的哈希组合,实际应用中应选择更好的哈希组合方法
return std::hash<int>()(p.x) ^ (std::hash<int>()(p.y) << 1);
}
};

然后,在实例化 MyHashMap 时使用自定义哈希函数:

1
2
3
MyHashMap<Point, std::string, PointHasher> pointMap;
Point p1{1, 2};
pointMap.insert(p1, "Point1");

Hash() 在代码中的作用

在 MyHashMap 的构造函数中,Hash() 用于实例化哈希函数对象,并将其赋值给成员变量 hash_func_:

1
2
3
4
5
6
7
8
9
// 构造函数
template <typename Key, typename T, typename Hash>
MyHashMap<Key, T, Hash>::MyHashMap(size_type initial_capacity, double max_load_factor)
: bucket_count_(initial_capacity),
element_count_(0),
max_load_factor_(max_load_factor),
hash_func_(Hash()) { // 这里的 Hash() 是一个默认构造函数调用
buckets_.resize(bucket_count_, nullptr);
}

具体作用:

  1. 实例化哈希函数对象:
    • Hash() 会调用 Hash 类型的默认构造函数,创建一个哈希函数对象。
    • 如果 Hash 是 std::hash<Key>,那么就创建一个 std::hash<Key> 对象。
  2. 存储哈希函数对象:
    • 生成的哈希函数对象被存储在成员变量 hash_func_ 中,以便在哈希表的各种操作(如插入、查找、删除)中使用。
  3. 支持不同的哈希函数:
    • 由于 Hash 是一个模板参数,可以灵活地使用不同的哈希函数,无需修改 MyHashMap 的内部实现。

示例代码解释

以下是相关部分的简化示例,帮助理解 Hash 和 Hash() 的作用:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
#include <vector>
#include <functional> // For std::hash

template <typename Key, typename T, typename Hash = std::hash<Key>>
class MyHashMap {
public:
MyHashMap(size_t initial_capacity = 16, double max_load_factor = 0.75)
: bucket_count_(initial_capacity),
element_count_(0),
max_load_factor_(max_load_factor),
hash_func_(Hash()) { // 实例化哈希函数对象
buckets_.resize(bucket_count_, nullptr);
}

void insert(const Key& key, const T& value) {
size_t hash_value = hash_func_(key); // 使用哈希函数对象
size_t index = hash_value % bucket_count_;
// 插入逻辑...
}

private:
std::vector<HashNode<Key, T>*> buckets_;
size_t bucket_count_;
size_t element_count_;
double max_load_factor_;
Hash hash_func_; // 存储哈希函数对象
};

总结

  • Hash 是模板参数,用于指定键类型 Key 的哈希函数。默认情况下,它使用 C++ 标准库中的 std::hash<Key>。
  • Hash() 是一个默认构造函数调用,用于实例化哈希函数对象,并将其存储在 hash_func_ 成员变量中,以便在哈希表操作中使用。
  • 用户可以自定义哈希函数,通过提供自定义的哈希函数对象,实现对特定键类型的优化或满足特殊需求。

手写线程安全智能指针

Posted on 2024-12-27 | In 零基础C++

现有 SimpleSharedPtr 的线程安全性分析

在多线程环境下,确保智能指针的线程安全性主要涉及以下几个方面:

  1. 引用计数管理:多个线程可能会同时拷贝、移动或销毁智能指针实例,导致引用计数的修改。若引用计数不是原子操作,则会引发数据竞争和未定义行为。
  2. 指针和控制块的访问:多个线程可能会同时访问或修改同一个智能指针实例的 ptr 和 control 成员,这需要同步机制来保护。

当前 SimpleSharedPtr 的问题:

  • 引用计数非原子:ControlBlock::ref_count 是普通的 int 类型,当多个线程同时修改 ref_count 时,会引发竞态条件。
  • 缺乏同步机制:SimpleSharedPtr 的成员函数(如拷贝构造、赋值操作符等)在修改 ptr 和 control 时没有任何同步机制,导致多个线程同时操作同一个 SimpleSharedPtr 实例时不安全。

实现线程安全的 SimpleSharedPtr

为了解决上述问题,可以从以下几个方面入手:

方法一:使用 std::atomic 管理引用计数

将 ControlBlock::ref_count 从普通的 int 替换为 std::atomic<int>,以确保引用计数的线程安全递增和递减。

优点:

  • 简单高效,避免使用互斥锁带来的性能开销。
  • 类似于标准库中 std::shared_ptr 实现的引用计数管理。

缺点:

  • 只能保证引用计数本身的线程安全,无法保护 ptr 和 control 的同步访问。

方法二:引入互斥锁保护指针操作

在 SimpleSharedPtr 中引入 std::mutex,在所有可能修改 ptr 和 control 的操作中加锁。

优点:

  • 确保 ptr 和 control 在多线程访问时的一致性。
  • 提供更全面的线程安全保障。

缺点:

  • 引入锁机制,可能带来性能开销,特别是在高并发场景下。

方法三:组合使用 std::atomic 和互斥锁

结合使用 std::atomic<int> 进行引用计数的管理,并使用 std::mutex 保护指针和控制块的访问。

优点:

  • 引用计数管理高效且线程安全。
  • 指针和控制块的访问得到完全的同步保护。

缺点:

  • 复杂性较高,需要同时管理原子操作和互斥锁。

完整线程安全的 ThreadSafeSharedPtr 实现

结合上述方法二和方法一,我们可以实现一个名为 ThreadSafeSharedPtr 的类模板,确保在多线程环境下的安全性。以下是具体实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
#include <iostream>
#include <atomic>
#include <mutex>
#include <thread>

// 控制块结构
struct ControlBlock {
std::atomic<int> ref_count;

ControlBlock() : ref_count(1) {}
};

// 线程安全的 shared_ptr 实现
template <typename T>
class ThreadSafeSharedPtr {
private:
T* ptr; // 指向管理的对象
ControlBlock* control; // 指向控制块

// 互斥锁,用于保护 ptr 和 control
mutable std::mutex mtx;

// 释放当前资源
void release() {
if (control) {
// 原子递减引用计数
if (--(control->ref_count) == 0) {
delete ptr;
delete control;
std::cout << "Resource and ControlBlock destroyed." << std::endl;
} else {
std::cout << "Decremented ref_count to " << control->ref_count.load() << std::endl;
}
}
ptr = nullptr;
control = nullptr;
}

public:
// 默认构造函数
ThreadSafeSharedPtr() : ptr(nullptr), control(nullptr) {
std::cout << "Default constructed ThreadSafeSharedPtr (nullptr)." << std::endl;
}

// 参数化构造函数
explicit ThreadSafeSharedPtr(T* p) : ptr(p) {
if (p) {
control = new ControlBlock();
std::cout << "Constructed ThreadSafeSharedPtr, ref_count = " << control->ref_count.load() << std::endl;
} else {
control = nullptr;
}
}

// 拷贝构造函数
ThreadSafeSharedPtr(const ThreadSafeSharedPtr& other) {
std::lock_guard<std::mutex> lock(other.mtx);
ptr = other.ptr;
control = other.control;
if (control) {
control->ref_count++;
std::cout << "Copied ThreadSafeSharedPtr, ref_count = " << control->ref_count.load() << std::endl;
}
}

// 拷贝赋值操作符
ThreadSafeSharedPtr& operator=(const ThreadSafeSharedPtr& other) {
if (this != &other) {
// 为避免死锁,使用 std::scoped_lock 同时锁定两个互斥锁
std::scoped_lock lock(mtx, other.mtx);
release();
ptr = other.ptr;
control = other.control;
if (control) {
control->ref_count++;
std::cout << "Assigned ThreadSafeSharedPtr, ref_count = " << control->ref_count.load() << std::endl;
}
}
return *this;
}

// 移动构造函数
ThreadSafeSharedPtr(ThreadSafeSharedPtr&& other) noexcept {
std::lock_guard<std::mutex> lock(other.mtx);
ptr = other.ptr;
control = other.control;
other.ptr = nullptr;
other.control = nullptr;
std::cout << "Moved ThreadSafeSharedPtr." << std::endl;
}

// 移动赋值操作符
ThreadSafeSharedPtr& operator=(ThreadSafeSharedPtr&& other) noexcept {
if (this != &other) {
// 为避免死锁,使用 std::scoped_lock 同时锁定两个互斥锁
std::scoped_lock lock(mtx, other.mtx);
release();
ptr = other.ptr;
control = other.control;
other.ptr = nullptr;
other.control = nullptr;
std::cout << "Move-assigned ThreadSafeSharedPtr." << std::endl;
}
return *this;
}

// 析构函数
~ThreadSafeSharedPtr() {
release();
}

// 解引用操作符
T& operator*() const {
std::lock_guard<std::mutex> lock(mtx);
return *ptr;
}

// 箭头操作符
T* operator->() const {
std::lock_guard<std::mutex> lock(mtx);
return ptr;
}

// 获取引用计数
int use_count() const {
std::lock_guard<std::mutex> lock(mtx);
return control ? control->ref_count.load() : 0;
}

// 获取裸指针
T* get() const {
std::lock_guard<std::mutex> lock(mtx);
return ptr;
}

// 重置指针
void reset(T* p = nullptr) {
std::lock_guard<std::mutex> lock(mtx);
release();
ptr = p;
if (p) {
control = new ControlBlock();
std::cout << "Reset ThreadSafeSharedPtr, ref_count = " << control->ref_count.load() << std::endl;
} else {
control = nullptr;
}
}
};

// 测试类
class Test {
public:
Test(int val) : value(val) {
std::cout << "Test Constructor: " << value << std::endl;
}
~Test() {
std::cout << "Test Destructor: " << value << std::endl;
}
void show() const {
std::cout << "Value: " << value << std::endl;
}

private:
int value;
};

关键改动说明

  1. 引用计数原子化:

    • 将

      1
      ControlBlock::ref_count

      从普通的

      1
      int

      改为

      1
      std::atomic<int>

      :

      1
      std::atomic<int> ref_count;
    • 使用原子操作管理引用计数,确保多线程下的安全递增和递减:

      1
      2
      control->ref_count++;
      if (--(control->ref_count) == 0) { ... }
    • 使用 ref_count.load() 获取当前引用计数的值。

  2. 引入互斥锁:

    • 在

      1
      ThreadSafeSharedPtr

      中引入

      1
      std::mutex mtx

      ,用于保护

      1
      ptr

      和

      1
      control

      的访问:

      1
      mutable std::mutex mtx;
    • 在所有可能修改或访问

      1
      ptr

      和

      1
      control

      的成员函数中加锁,确保同步:

      1
      std::lock_guard<std::mutex> lock(mtx);
    • 在拷贝构造函数和拷贝赋值操作符中,为避免死锁,使用

      1
      std::scoped_lock

      同时锁定两个互斥锁:

      1
      std::scoped_lock lock(mtx, other.mtx);
  3. 线程安全的成员函数:

    • 对于 operator* 和 operator->,在返回前锁定互斥锁,确保在多线程环境中的安全访问。
    • 其他成员函数如 use_count、get 和 reset 同样在访问共享资源前加锁。

注意事项

  • 避免死锁:在需要同时锁定多个互斥锁时,使用 std::scoped_lock(C++17 引入)可以同时锁定多个互斥锁,避免死锁风险。
  • 性能开销:引入互斥锁会带来一定的性能开销,尤其是在高并发场景下。根据实际需求,权衡线程安全性和性能之间的关系。

测试线程安全的 ThreadSafeSharedPtr

为了验证 ThreadSafeSharedPtr 的线程安全性,我们可以编写一个多线程程序,让多个线程同时拷贝、赋值和销毁智能指针。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
#include <iostream>
#include <thread>
#include <vector>
#include "ThreadSafeSharedPtr.h" // 假设将上述代码保存为该头文件

// 测试类
class Test {
public:
Test(int val) : value(val) {
std::cout << "Test Constructor: " << value << std::endl;
}
~Test() {
std::cout << "Test Destructor: " << value << std::endl;
}
void show() const {
std::cout << "Value: " << value << std::endl;
}

private:
int value;
};

void thread_func_copy(ThreadSafeSharedPtr<Test> sptr, int thread_id) {
std::cout << "Thread " << thread_id << " is copying shared_ptr." << std::endl;
ThreadSafeSharedPtr<Test> local_sptr = sptr;
std::cout << "Thread " << thread_id << " copied shared_ptr, use_count = " << local_sptr.use_count() << std::endl;
local_sptr->show();
}

void thread_func_reset(ThreadSafeSharedPtr<Test>& sptr, int new_val, int thread_id) {
std::cout << "Thread " << thread_id << " is resetting shared_ptr." << std::endl;
sptr.reset(new Test(new_val));
std::cout << "Thread " << thread_id << " reset shared_ptr, use_count = " << sptr.use_count() << std::endl;
sptr->show();
}

int main() {
std::cout << "Creating ThreadSafeSharedPtr with Test(100)." << std::endl;
ThreadSafeSharedPtr<Test> sptr(new Test(100));
std::cout << "Initial use_count: " << sptr.use_count() << std::endl;

// 创建多个线程进行拷贝操作
const int num_threads = 5;
std::vector<std::thread> threads_copy;

for(int i = 0; i < num_threads; ++i) {
threads_copy.emplace_back(thread_func_copy, sptr, i);
}

for(auto& t : threads_copy) {
t.join();
}

std::cout << "After copy threads, use_count: " << sptr.use_count() << std::endl;

// 创建多个线程进行 reset 操作
std::vector<std::thread> threads_reset;

for(int i = 0; i < num_threads; ++i) {
threads_reset.emplace_back(thread_func_reset, std::ref(sptr), 200 + i, i);
}

for(auto& t : threads_reset) {
t.join();
}

std::cout << "After reset threads, final use_count: " << sptr.use_count() << std::endl;

std::cout << "Exiting main." << std::endl;
return 0;
}

预期输出示例(具体顺序可能因线程调度而异):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
Creating ThreadSafeSharedPtr with Test(100).
Test Constructor: 100
Constructed ThreadSafeSharedPtr, ref_count = 1
Initial use_count: 1
Thread 0 is copying shared_ptr.
Copied ThreadSafeSharedPtr, ref_count = 2
Thread 0 copied shared_ptr, use_count = 2
Value: 100
Thread 1 is copying shared_ptr.
Copied ThreadSafeSharedPtr, ref_count = 3
Thread 1 copied shared_ptr, use_count = 3
Value: 100
...
After copy threads, use_count: 6
Thread 0 is resetting shared_ptr.
Decremented ref_count to 5
Resource and ControlBlock destroyed.
Test Constructor: 200
Reset ThreadSafeSharedPtr, ref_count = 1
Value: 200
...
After reset threads, final use_count: 1
Exiting main.
Test Destructor: 200

说明:

  • 多个线程同时拷贝 sptr,引用计数正确递增。
  • 多个线程同时重置 sptr,确保引用计数和资源管理的正确性。
  • 最终,只有最新分配的对象存在,引用计数为 1。

注意事项和最佳实践

  1. 引用计数的原子性:
    • 使用 std::atomic<int> 来保证引用计数的线程安全递增和递减。
    • 避免使用普通的 int,因为在多线程环境下会导致数据竞争。
  2. 互斥锁的使用:
    • 使用 std::mutex 来保护 ptr 和 control 的访问,防止多个线程同时修改智能指针实例。
    • 尽量缩小锁的范围,避免在互斥锁保护的临界区内执行耗时操作,以减少性能开销。
  3. 避免死锁:
    • 在需要同时锁定多个互斥锁时,使用 std::scoped_lock 来一次性锁定,确保锁的顺序一致,避免死锁风险。
  4. 尽量遵循 RAII 原则:
    • 使用 std::lock_guard 或 std::scoped_lock 等 RAII 机制来管理互斥锁,确保在异常抛出时自动释放锁,防止死锁。
  5. 避免多重管理:
    • 确保不通过裸指针绕过智能指针的引用计数管理,避免资源泄漏或重复释放。
  6. 性能考虑:
    • 在高并发场景下,频繁的锁操作可能成为性能瓶颈。根据实际需求,可以考虑使用更轻量级的同步机制,如 std::shared_mutex(C++17)用于读多写少的场景。

总结

通过将 ControlBlock::ref_count 改为 std::atomic<int>,并在 ThreadSafeSharedPtr 中引入互斥锁来保护 ptr 和 control 的访问,可以实现一个线程安全的智能指针。这种实现确保了在多线程环境下,多个线程可以安全地拷贝、赋值和销毁智能指针,同时正确管理引用计数和资源。

关键点总结:

  • 引用计数的原子性:使用 std::atomic<int> 保证引用计数操作的线程安全。
  • 互斥锁保护:使用 std::mutex 保护智能指针实例的内部状态,防止多个线程同时修改。
  • RAII 机制:利用 std::lock_guard 和 std::scoped_lock 等 RAII 机制,确保锁的正确管理和释放。
  • 避免死锁:在需要同时锁定多个互斥锁时,使用 std::scoped_lock 以避免死锁风险。
  • 性能优化:平衡线程安全性和性能,避免不必要的锁竞争。

map用法和手写map

Posted on 2024-12-27 | In 零基础C++

std::map用法

std::map 是 C++ 标准模板库(STL)中的一个关联容器,用于存储键值对(key-value pairs),其中每个键都是唯一的,并且按照特定的顺序(通常是升序)自动排序。std::map 通常基于红黑树实现,提供对元素的高效查找、插入和删除操作。

1. 基本特性

  • 有序性:std::map 中的元素按照键的顺序自动排序,默认使用 < 运算符进行比较。
  • 唯一键:每个键在 std::map 中必须是唯一的,如果尝试插入重复的键,则插入操作会失败。
  • 关联容器:通过键快速访问对应的值,通常具有对数时间复杂度(O(log n))。
  • 可变性:可以动态地插入和删除元素。

2. 头文件和命名空间

要使用 std::map,需要包含头文件 <map> 并使用 std 命名空间:

1
2
3
4
5
#include <map>
#include <iostream>
#include <string>

using namespace std;

3. 声明和初始化

3.1 声明一个 std::map

1
2
3
4
5
// 键为 int,值为 std::string 的 map
map<int, string> myMap;

// 键为 std::string,值为 double 的 map
map<string, double> priceMap;

3.2 初始化 std::map

可以使用初始化列表或其他方法初始化 map:

1
2
3
4
5
map<int, string> myMap = {
{1, "Apple"},
{2, "Banana"},
{3, "Cherry"}
};

4. 主要操作

4.1 插入元素

有几种方法可以向 std::map 中插入元素:

4.1.1 使用 insert 函数

1
2
3
4
5
myMap.insert(pair<int, string>(4, "Date"));
// 或者使用 `make_pair`
myMap.insert(make_pair(5, "Elderberry"));
// 或者使用初始化列表
myMap.insert({6, "Fig"});

4.1.2 使用下标运算符 []

1
2
3
myMap[7] = "Grape";
// 如果键 8 不存在,则会插入键 8 并赋值
myMap[8] = "Honeydew";

注意:使用 [] 运算符时,如果键不存在,会自动插入该键,并将值初始化为类型的默认值。

4.2 访问元素

4.2.1 使用下标运算符 []

1
string fruit = myMap[1]; // 获取键为 1 的值 "Apple"

注意:如果键不存在,[] 会插入该键并返回默认值。

4.2.2 使用 at 成员函数

1
2
3
4
5
try {
string fruit = myMap.at(2); // 获取键为 2 的值 "Banana"
} catch (const out_of_range& e) {
cout << "Key not found." << endl;
}

at 函数在键不存在时会抛出 std::out_of_range 异常,适合需要异常处理的场景。

4.2.3 使用 find 成员函数

1
2
3
4
5
6
auto it = myMap.find(3);
if (it != myMap.end()) {
cout << "Key 3: " << it->second << endl; // 输出 "Cherry"
} else {
cout << "Key 3 not found." << endl;
}

find 返回一个迭代器,指向找到的元素,若未找到则返回 map::end()。

4.3 删除元素

4.3.1 使用 erase 函数

1
2
3
4
5
6
7
8
9
10
11
// 按键删除
myMap.erase(2);

// 按迭代器删除
auto it = myMap.find(3);
if (it != myMap.end()) {
myMap.erase(it);
}

// 删除区间 [first, last)
myMap.erase(myMap.begin(), myMap.find(5));

4.3.2 使用 clear 函数

1
myMap.clear(); // 删除所有元素

4.4 遍历 std::map

4.4.1 使用迭代器

1
2
3
for (map<int, string>::iterator it = myMap.begin(); it != myMap.end(); ++it) {
cout << "Key: " << it->first << ", Value: " << it->second << endl;
}

4.4.2 使用基于范围的 for 循环(C++11 及以上)

1
2
3
for (const auto& pair : myMap) {
cout << "Key: " << pair.first << ", Value: " << pair.second << endl;
}

4.5 常用成员函数

  • **size()**:返回容器中元素的数量。
  • **empty()**:判断容器是否为空。
  • **count(key)**:返回具有指定键的元素数量(对于 map,返回 0 或 1)。
  • lower_bound(key) 和 **upper_bound(key)**:返回迭代器,分别指向第一个不小于和第一个大于指定键的元素。
  • **equal_range(key)**:返回一个范围,包含所有等于指定键的元素。

5. 自定义键的排序

默认情况下,std::map 使用 < 运算符对键进行排序。如果需要自定义排序方式,可以提供一个自定义的比较函数或函数对象。

5.1 使用函数对象

1
2
3
4
5
6
7
8
9
10
struct Compare {
bool operator()(const int& a, const int& b) const {
return a > b; // 降序排序
}
};

map<int, string, Compare> myMapDesc;
myMapDesc[1] = "Apple";
myMapDesc[2] = "Banana";
// ...

5.2 使用 lambda 表达式(C++11 及以上)

需要注意,std::map 的第三个模板参数必须是可比较类型,不能直接使用 lambda 表达式作为模板参数。不过,可以使用 std::function 或自定义结构体来包装 lambda。

1
2
3
4
5
6
7
8
9
struct CompareLambda {
bool operator()(const int& a, const int& b) const {
return a > b; // 降序排序
}
};

map<int, string, CompareLambda> myMapDesc;
myMapDesc[1] = "Apple";
// ...

6. std::map 与其他关联容器的比较

  • **std::unordered_map**:基于哈希表实现,提供平均常数时间复杂度的查找、插入和删除操作,但不保证元素的顺序。适用于对顺序无要求且需要高效查找的场景。
  • **std::multimap**:允许多个相同键的元素,其他特性与 std::map 类似。适用于需要存储重复键值对的场景。

7. 性能考虑

  • 时间复杂度

    :

    • 查找、插入、删除:O(log n)
    • 遍历:O(n)
  • 空间复杂度:std::map 通常需要额外的空间来维护树结构,相比 std::vector 等序列容器,内存开销更大。

选择使用 std::map 还是其他容器,应根据具体需求和性能要求进行权衡。

8. 完整示例

以下是一个完整的示例,展示了 std::map 的基本用法:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
#include <iostream>
#include <map>
#include <string>

using namespace std;

int main() {
// 创建一个 map,键为 int,值为 string
map<int, string> myMap;

// 插入元素
myMap[1] = "Apple";
myMap[2] = "Banana";
myMap.insert({3, "Cherry"});
myMap.insert(make_pair(4, "Date"));

// 访问元素
cout << "Key 1: " << myMap[1] << endl;
cout << "Key 2: " << myMap.at(2) << endl;

// 查找元素
int keyToFind = 3;
auto it = myMap.find(keyToFind);
if (it != myMap.end()) {
cout << "Found key " << keyToFind << ": " << it->second << endl;
} else {
cout << "Key " << keyToFind << " not found." << endl;
}

// 遍历 map
cout << "All elements:" << endl;
for (const auto& pair : myMap) {
cout << "Key: " << pair.first << ", Value: " << pair.second << endl;
}

// 删除元素
myMap.erase(2);
cout << "After deleting key 2:" << endl;
for (const auto& pair : myMap) {
cout << "Key: " << pair.first << ", Value: " << pair.second << endl;
}

// 检查是否为空
if (!myMap.empty()) {
cout << "Map is not empty. Size: " << myMap.size() << endl;
}

// 清空所有元素
myMap.clear();
cout << "After clearing, map is " << (myMap.empty() ? "empty." : "not empty.") << endl;

return 0;
}

输出:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
Key 1: Apple
Key 2: Banana
Found key 3: Cherry
All elements:
Key: 1, Value: Apple
Key: 2, Value: Banana
Key: 3, Value: Cherry
Key: 4, Value: Date
After deleting key 2:
Key: 1, Value: Apple
Key: 3, Value: Cherry
Key: 4, Value: Date
Map is not empty. Size: 3
After clearing, map is empty.

BST实现map

1. 选择底层数据结构

std::map 通常基于平衡的二叉搜索树(如红黑树)实现,以保证操作的时间复杂度为 O(log n)。为了简化实现,本文将采用普通的二叉搜索树,即不进行自平衡处理。不过在实际应用中,为了保证性能,建议使用自平衡的树结构(例如红黑树、AVL 树)。

2. 设计数据结构

2.1 节点结构

首先,需要定义树的节点结构,每个节点包含键值对以及指向子节点的指针。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
#include <iostream>
#include <stack>
#include <utility> // For std::pair
#include <exception>

template <typename Key, typename T>
struct TreeNode {
std::pair<Key, T> data;
TreeNode* left;
TreeNode* right;
TreeNode* parent;

TreeNode(const Key& key, const T& value, TreeNode* parentNode = nullptr)
: data(std::make_pair(key, value)), left(nullptr), right(nullptr), parent(parentNode) {}
};

2.2 Map 类的定义

接下来,定义 MyMap 类,包含根节点和基本操作。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
template <typename Key, typename T>
class MyMap {
public:
MyMap() : root(nullptr) {}
~MyMap() { clear(root); }

// 禁止拷贝构造和赋值
MyMap(const MyMap&) = delete;
MyMap& operator=(const MyMap&) = delete;

// 插入或更新键值对
void insert(const Key& key, const T& value) {
if (root == nullptr) {
root = new TreeNode<Key, T>(key, value);
return;
}

TreeNode<Key, T>* current = root;
TreeNode<Key, T>* parent = nullptr;

while (current != nullptr) {
parent = current;
if (key < current->data.first) {
current = current->left;
} else if (key > current->data.first) {
current = current->right;
} else {
// 键已存在,更新值
current->data.second = value;
return;
}
}

if (key < parent->data.first) {
parent->left = new TreeNode<Key, T>(key, value, parent);
} else {
parent->right = new TreeNode<Key, T>(key, value, parent);
}
}

// 查找元素,返回指向节点的指针
TreeNode<Key, T>* find(const Key& key) const {
TreeNode<Key, T>* current = root;
while (current != nullptr) {
if (key < current->data.first) {
current = current->left;
} else if (key > current->data.first) {
current = current->right;
} else {
return current;
}
}
return nullptr;
}

// 删除元素
void erase(const Key& key) {
TreeNode<Key, T>* node = find(key);
if (node == nullptr) return; // 没有找到要删除的节点

// 节点有两个子节点
if (node->left != nullptr && node->right != nullptr) {
// 找到中序后继
TreeNode<Key, T>* successor = minimum(node->right);
node->data = successor->data; // 替换数据
node = successor; // 将要删除的节点指向后继节点
}

// 节点有一个或没有子节点
TreeNode<Key, T>* child = (node->left) ? node->left : node->right;
if (child != nullptr) {
child->parent = node->parent;
}

if (node->parent == nullptr) {
root = child;
} else if (node == node->parent->left) {
node->parent->left = child;
} else {
node->parent->right = child;
}

delete node;
}

// 清空所有节点
void clear() {
clear(root);
root = nullptr;
}

// 获取迭代器
class Iterator {
public:
Iterator(TreeNode<Key, T>* node = nullptr) : current(node) {}

std::pair<const Key, T>& operator*() const {
return current->data;
}

std::pair<const Key, T>* operator->() const {
return &(current->data);
}

// 前置递增
Iterator& operator++() {
current = successor(current);
return *this;
}

// 后置递增
Iterator operator++(int) {
Iterator temp = *this;
current = successor(current);
return temp;
}

bool operator==(const Iterator& other) const {
return current == other.current;
}

bool operator!=(const Iterator& other) const {
return current != other.current;
}

private:
TreeNode<Key, T>* current;

TreeNode<Key, T>* minimum(TreeNode<Key, T>* node) const {
while (node->left != nullptr) {
node = node->left;
}
return node;
}

TreeNode<Key, T>* successor(TreeNode<Key, T>* node) const {
if (node->right != nullptr) {
return minimum(node->right);
}

TreeNode<Key, T>* p = node->parent;
while (p != nullptr && node == p->right) {
node = p;
p = p->parent;
}
return p;
}
};

Iterator begin() const {
return Iterator(minimum(root));
}

Iterator end() const {
return Iterator(nullptr);
}

private:
TreeNode<Key, T>* root;

// 删除树中的所有节点
void clear(TreeNode<Key, T>* node) {
if (node == nullptr) return;
clear(node->left);
clear(node->right);
delete node;
}

// 找到最小的节点
TreeNode<Key, T>* minimum(TreeNode<Key, T>* node) const {
if (node == nullptr) return nullptr;
while (node->left != nullptr) {
node = node->left;
}
return node;
}

// 找到最大的节点
TreeNode<Key, T>* maximum(TreeNode<Key, T>* node) const {
if (node == nullptr) return nullptr;
while (node->right != nullptr) {
node = node->right;
}
return node;
}
};

2.3 解释

  1. TreeNode 结构:

    • data: 存储键值对 std::pair<Key, T>。
    • left 和 right: 指向左子节点和右子节点。
    • parent: 指向父节点,便于迭代器中查找后继节点。
  2. MyMap 类:

    • 构造与析构

      :

      • 构造函数初始化根节点为空。
      • 析构函数调用 clear 释放所有节点内存。
    • 插入 (insert)

      :

      • 从根节点开始,根据键的大小确定插入左子树还是右子树。
      • 如果键已存在,更新对应的值。
    • 查找 (find)

      :

      • 按照键的大小在树中查找对应的节点。
    • 删除 (erase)

      :

      • 查找到目标节点。
      • 如果节点有两个子节点,找到中序后继节点并替换当前节点的数据,然后删除后继节点。
      • 如果节点有一个或没有子节点,直接替换节点指针并删除节点。
    • 清空 (clear)

      :

      • 使用递归方式删除所有节点。
    • 迭代器

      :

      • 定义了嵌套的 Iterator 类,支持前置和后置递增操作。
      • 迭代器通过中序遍历实现,保证键的顺序性。
      • begin() 返回最小的节点,end() 返回 nullptr。

3. 使用示例

下面提供一个使用 MyMap 的示例,展示插入、查找、删除和迭代操作。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
int main() {
MyMap<std::string, int> myMap;

// 插入元素
myMap.insert("apple", 3);
myMap.insert("banana", 5);
myMap.insert("orange", 2);
myMap.insert("grape", 7);
myMap.insert("cherry", 4);

// 使用迭代器遍历元素(按键的字母顺序)
std::cout << "Map contents (in-order):\n";
for(auto it = myMap.begin(); it != myMap.end(); ++it) {
std::cout << it->first << " => " << it->second << "\n";
}

// 查找元素
std::string keyToFind = "banana";
auto node = myMap.find(keyToFind);
if(node != nullptr) {
std::cout << "\nFound " << keyToFind << " with value: " << node->data.second << "\n";
} else {
std::cout << "\n" << keyToFind << " not found.\n";
}

// 删除元素
myMap.erase("apple");
myMap.erase("cherry");

// 再次遍历
std::cout << "\nAfter erasing apple and cherry:\n";
for(auto it = myMap.begin(); it != myMap.end(); ++it) {
std::cout << it->first << " => " << it->second << "\n";
}

return 0;
}

输出结果

1
2
3
4
5
6
7
8
9
10
11
12
13
Map contents (in-order):
apple => 3
banana => 5
cherry => 4
grape => 7
orange => 2

Found banana with value: 5

After erasing apple and cherry:
banana => 5
grape => 7
orange => 2

4. 迭代器的详细实现

为了支持迭代器的正常使用,Iterator 类实现了以下功能:

  • **解引用操作符 (operator\* 和 operator->)**:
    • 允许访问键值对,如 it->first 和 it->second。
  • **递增操作符 (operator++)**:
    • 前置递增(++it)和后置递增(it++)用于移动到下一个元素。
    • 通过查找当前节点的中序后继节点实现。
  • **比较操作符 (operator== 和 operator!=)**:
    • 判断两个迭代器是否指向同一个节点。

中序后继节点的查找

迭代器使用中序遍历来确保键的有序性。计算后继节点的步骤如下:

  1. 当前节点有右子树:
    • 后继节点是右子树中的最小节点。
  2. 当前节点没有右子树:
    • 向上查找,直到找到一个节点是其父节点的左子树,此时父节点即为后继节点。

如果没有后继节点(即当前节点是最大的节点),则返回 nullptr,表示迭代器到达 end()。

5. 扩展功能

上述实现是一个基本的 Map,还可以根据需要扩展更多功能,例如:

  • 支持 const 迭代器:
    • 定义 const_iterator,确保在只读操作时数据不被修改。
  • 实现平衡树:
    • 为了提高性能,可以实现红黑树、AVL 树等自平衡二叉搜索树,保证操作的时间复杂度为 O(log n)。
  • 添加更多成员函数:
    • 如 operator[]、count、lower_bound、upper_bound 等,增加容器的功能性。
  • 异常处理:
    • 增加对异常情况的处理,例如在删除不存在的键时抛出异常等。
  • 迭代器的逆向遍历:
    • 实现双向迭代器,支持逆序遍历(rbegin() 和 rend())。

AVL树

AVL树(Adelson-Velsky and Landis树)是一种自平衡的二叉搜索树(BST),它在插入和删除操作后通过旋转来保持树的平衡,从而确保基本操作(如查找、插入和删除)的时间复杂度保持在O(log n)。使用AVL树来实现map(键值对映射)是一个高效的选择,特别适合需要频繁查找、插入和删除操作的场景。

1. 模板化AVL树节点

首先,我们需要将AVLNode结构模板化,使其能够处理不同类型的键和值。我们假设键类型KeyType支持operator<进行比较,因为AVL树需要对键进行排序以维护其性质。

首先,我们定义AVL树的节点。每个节点包含一个键(key)、一个值(value)、节点高度(height),以及指向左子节点和右子节点的指针。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
#include <iostream>
#include <string>
#include <vector>
#include <algorithm> // 用于 std::max
#include <functional> // 用于 std::function

// 模板化的AVL树节点结构
template <typename KeyType, typename ValueType>
struct AVLNode {
KeyType key;
ValueType value;
int height;
AVLNode* left;
AVLNode* right;

AVLNode(const KeyType& k, const ValueType& val)
: key(k), value(val), height(1), left(nullptr), right(nullptr) {}
};

说明:

  • KeyType:键的类型,需要支持比较操作(如 operator<)。
  • ValueType:值的类型,可以是任何类型。

2. 辅助函数的模板化

辅助函数同样需要模板化,以适应不同的AVLNode类型。

获取节点高度

获取节点的高度。如果节点为空,则高度为0。

1
2
3
4
5
6
template <typename KeyType, typename ValueType>
int getHeight(AVLNode<KeyType, ValueType>* node) {
if (node == nullptr)
return 0;
return node->height;
}

获取平衡因子

平衡因子(Balance Factor)是左子树高度减去右子树高度。

1
2
3
4
5
6
template <typename KeyType, typename ValueType>
int getBalance(AVLNode<KeyType, ValueType>* node) {
if (node == nullptr)
return 0;
return getHeight(node->left) - getHeight(node->right);
}

右旋转

右旋转用于处理左子树过高的情况(例如,左左情况)。

1
2
3
4
5
    y                             x
/ \ / \
x T3 ==> z y
/ \ / \
z T2 T2 T3

实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
template <typename KeyType, typename ValueType>
AVLNode<KeyType, ValueType>* rightRotate(AVLNode<KeyType, ValueType>* y) {
AVLNode<KeyType, ValueType>* x = y->left;
AVLNode<KeyType, ValueType>* T2 = x->right;

// 执行旋转
x->right = y;
y->left = T2;

// 更新高度
y->height = std::max(getHeight(y->left), getHeight(y->right)) + 1;
x->height = std::max(getHeight(x->left), getHeight(x->right)) + 1;

// 返回新的根
return x;
}

左旋转

左旋转用于处理右子树过高的情况(例如,右右情况)。

1
2
3
4
5
  x                             y
/ \ / \
T1 y ==> x z
/ \ / \
T2 z T1 T2

具体实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
template <typename KeyType, typename ValueType>
AVLNode<KeyType, ValueType>* leftRotate(AVLNode<KeyType, ValueType>* x) {
AVLNode<KeyType, ValueType>* y = x->right;
AVLNode<KeyType, ValueType>* T2 = y->left;

// 执行旋转
y->left = x;
x->right = T2;

// 更新高度
x->height = std::max(getHeight(x->left), getHeight(x->right)) + 1;
y->height = std::max(getHeight(y->left), getHeight(y->right)) + 1;

// 返回新的根
return y;
}

3. AVL树的核心操作模板化

插入节点

插入操作遵循标准的二叉搜索树插入方式,然后通过旋转保持树的平衡。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
template <typename KeyType, typename ValueType>
AVLNode<KeyType, ValueType>* insertNode(AVLNode<KeyType, ValueType>* node, const KeyType& key, const ValueType& value) {
// 1. 执行标准的BST插入
if (node == nullptr)
return new AVLNode<KeyType, ValueType>(key, value);

if (key < node->key)
node->left = insertNode(node->left, key, value);
else if (key > node->key)
node->right = insertNode(node->right, key, value);
else {
// 如果键已经存在,更新其值
node->value = value;
return node;
}

// 2. 更新节点高度
node->height = 1 + std::max(getHeight(node->left), getHeight(node->right));

// 3. 获取平衡因子
int balance = getBalance(node);

// 4. 根据平衡因子进行旋转

// 左左情况
if (balance > 1 && key < node->left->key)
return rightRotate(node);

// 右右情况
if (balance < -1 && key > node->right->key)
return leftRotate(node);

// 左右情况
if (balance > 1 && key > node->left->key) {
node->left = leftRotate(node->left);
return rightRotate(node);
}

// 右左情况
if (balance < -1 && key < node->right->key) {
node->right = rightRotate(node->right);
return leftRotate(node);
}

return node;
}

查找节点

按键查找节点,返回对应的值。如果键不存在,返回nullptr。

1
2
3
4
5
6
7
8
9
10
11
12
template <typename KeyType, typename ValueType>
ValueType* searchNode(AVLNode<KeyType, ValueType>* node, const KeyType& key) {
if (node == nullptr)
return nullptr;

if (key == node->key)
return &(node->value);
else if (key < node->key)
return searchNode(node->left, key);
else
return searchNode(node->right, key);
}

获取最小值节点

用于删除节点时找到中序后继。

1
2
3
4
5
6
7
template <typename KeyType, typename ValueType>
AVLNode<KeyType, ValueType>* getMinValueNode(AVLNode<KeyType, ValueType>* node) {
AVLNode<KeyType, ValueType>* current = node;
while (current->left != nullptr)
current = current->left;
return current;
}

删除节点

删除操作分为三种情况:删除节点是叶子节点、有一个子节点或有两个子节点。删除后,通过旋转保持树的平衡。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
template <typename KeyType, typename ValueType>
AVLNode<KeyType, ValueType>* deleteNode(AVLNode<KeyType, ValueType>* root, const KeyType& key) {
// 1. 执行标准的BST删除
if (root == nullptr)
return root;

if (key < root->key)
root->left = deleteNode(root->left, key);
else if (key > root->key)
root->right = deleteNode(root->right, key);
else {
// 节点有一个或没有子节点
if ((root->left == nullptr) || (root->right == nullptr)) {
AVLNode<KeyType, ValueType>* temp = root->left ? root->left : root->right;

// 没有子节点
if (temp == nullptr) {
temp = root;
root = nullptr;
}
else // 一个子节点
*root = *temp; // 复制内容

delete temp;
}
else {
// 节点有两个子节点,获取中序后继
AVLNode<KeyType, ValueType>* temp = getMinValueNode(root->right);
// 复制中序后继的内容到此节点
root->key = temp->key;
root->value = temp->value;
// 删除中序后继
root->right = deleteNode(root->right, temp->key);
}
}

// 如果树只有一个节点
if (root == nullptr)
return root;

// 2. 更新节点高度
root->height = 1 + std::max(getHeight(root->left), getHeight(root->right));

// 3. 获取平衡因子
int balance = getBalance(root);

// 4. 根据平衡因子进行旋转

// 左左情况
if (balance > 1 && getBalance(root->left) >= 0)
return rightRotate(root);

// 左右情况
if (balance > 1 && getBalance(root->left) < 0) {
root->left = leftRotate(root->left);
return rightRotate(root);
}

// 右右情况
if (balance < -1 && getBalance(root->right) <= 0)
return leftRotate(root);

// 右左情况
if (balance < -1 && getBalance(root->right) > 0) {
root->right = rightRotate(root->right);
return leftRotate(root);
}

return root;
}

4. 模板化的AVLMap类

现在,我们将所有模板化的函数集成到一个模板类AVLMap中。这个类将提供如下功能:

  • put(const KeyType& key, const ValueType& value):插入或更新键值对。
  • get(const KeyType& key):查找键对应的值,返回指向值的指针,如果键不存在则返回nullptr。
  • remove(const KeyType& key):删除指定键的键值对。
  • inorderTraversal():中序遍历,返回有序的键值对列表。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
template <typename KeyType, typename ValueType>
class AVLMap {
private:
AVLNode<KeyType, ValueType>* root;

// 中序遍历辅助函数
void inorderHelper(AVLNode<KeyType, ValueType>* node, std::vector<std::pair<KeyType, ValueType>>& res) const {
if (node != nullptr) {
inorderHelper(node->left, res);
res.emplace_back(node->key, node->value);
inorderHelper(node->right, res);
}
}

public:
AVLMap() : root(nullptr) {}

// 插入或更新键值对
void put(const KeyType& key, const ValueType& value) {
root = insertNode(root, key, value);
}

// 查找值,返回指向值的指针,如果键不存在则返回nullptr
ValueType* get(const KeyType& key) {
return searchNode(root, key);
}

// 删除键值对
void remove(const KeyType& key) {
root = deleteNode(root, key);
}

// 中序遍历,返回有序的键值对
std::vector<std::pair<KeyType, ValueType>> inorderTraversal() const {
std::vector<std::pair<KeyType, ValueType>> res;
inorderHelper(root, res);
return res;
}

// 析构函数,释放所有节点的内存
~AVLMap() {
// 使用后序遍历释放节点
std::function<void(AVLNode<KeyType, ValueType>*)> destroy = [&](AVLNode<KeyType, ValueType>* node) {
if (node) {
destroy(node->left);
destroy(node->right);
delete node;
}
};
destroy(root);
}
};

说明:

  • 模板参数

    :

    • KeyType:键的类型,需要支持operator<进行比较。
    • ValueType:值的类型,可以是任意类型。
  • 内存管理

    :

    • 析构函数使用后序遍历释放所有动态分配的节点,防止内存泄漏。
  • 异常安全

    :

    • 当前实现没有处理异常情况。如果需要更高的异常安全性,可以进一步增强代码,例如在插入过程中捕获异常并回滚操作。

5. 使用示例

下面的示例展示了如何使用模板化的AVLMap类,使用不同类型的键和值。

示例 1:使用int作为键,std::string作为值

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
#include <iostream>
#include <string>
#include <vector>

// 假设上面的AVLNode结构、辅助函数和AVLMap类已经定义

int main() {
AVLMap<int, std::string> avlMap;

// 插入键值对
avlMap.put(10, "十");
avlMap.put(20, "二十");
avlMap.put(30, "三十");
avlMap.put(40, "四十");
avlMap.put(50, "五十");
avlMap.put(25, "二十五");

// 中序遍历
std::vector<std::pair<int, std::string>> traversal = avlMap.inorderTraversal();
std::cout << "中序遍历: ";
for (const auto& pair : traversal) {
std::cout << "(" << pair.first << ", \"" << pair.second << "\") ";
}
std::cout << std::endl;

// 查找键
std::string* val = avlMap.get(20);
if (val)
std::cout << "获取键20的值: " << *val << std::endl;
else
std::cout << "键20不存在。" << std::endl;

val = avlMap.get(25);
if (val)
std::cout << "获取键25的值: " << *val << std::endl;
else
std::cout << "键25不存在。" << std::endl;

val = avlMap.get(60);
if (val)
std::cout << "获取键60的值: " << *val << std::endl;
else
std::cout << "键60不存在。" << std::endl;

// 删除键20
avlMap.remove(20);
std::cout << "删除键20后,中序遍历: ";
traversal = avlMap.inorderTraversal();
for (const auto& pair : traversal) {
std::cout << "(" << pair.first << ", \"" << pair.second << "\") ";
}
std::cout << std::endl;

return 0;
}

输出:

1
2
3
4
5
中序遍历: (10, "十") (20, "二十") (25, "二十五") (30, "三十") (40, "四十") (50, "五十") 
获取键20的值: 二十
获取键25的值: 二十五
键60不存在。
删除键20后,中序遍历: (10, "十") (25, "二十五") (30, "三十") (40, "四十") (50, "五十")

示例 2:使用std::string作为键,double作为值

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
#include <iostream>
#include <string>
#include <vector>

// 假设上面的AVLNode结构、辅助函数和AVLMap类已经定义

int main() {
AVLMap<std::string, double> avlMap;

// 插入键值对
avlMap.put("apple", 1.99);
avlMap.put("banana", 0.99);
avlMap.put("cherry", 2.99);
avlMap.put("date", 3.49);
avlMap.put("elderberry", 5.99);
avlMap.put("fig", 2.49);

// 中序遍历
std::vector<std::pair<std::string, double>> traversal = avlMap.inorderTraversal();
std::cout << "中序遍历: ";
for (const auto& pair : traversal) {
std::cout << "(\"" << pair.first << "\", " << pair.second << ") ";
}
std::cout << std::endl;

// 查找键
double* val = avlMap.get("banana");
if (val)
std::cout << "获取键\"banana\"的值: " << *val << std::endl;
else
std::cout << "键\"banana\"不存在。" << std::endl;

val = avlMap.get("fig");
if (val)
std::cout << "获取键\"fig\"的值: " << *val << std::endl;
else
std::cout << "键\"fig\"不存在。" << std::endl;

val = avlMap.get("grape");
if (val)
std::cout << "获取键\"grape\"的值: " << *val << std::endl;
else
std::cout << "键\"grape\"不存在。" << std::endl;

// 删除键"banana"
avlMap.remove("banana");
std::cout << "删除键\"banana\"后,中序遍历: ";
traversal = avlMap.inorderTraversal();
for (const auto& pair : traversal) {
std::cout << "(\"" << pair.first << "\", " << pair.second << ") ";
}
std::cout << std::endl;

return 0;
}

输出:

1
2
3
4
5
中序遍历: ("apple", 1.99) ("banana", 0.99) ("cherry", 2.99) ("date", 3.49) ("elderberry", 5.99) ("fig", 2.49) 
获取键"banana"的值: 0.99
获取键"fig"的值: 2.49
键"grape"不存在。
删除键"banana"后,中序遍历: ("apple", 1.99) ("cherry", 2.99) ("date", 3.49) ("elderberry", 5.99) ("fig", 2.49)

6. 完整的通用代码

以下是模板化的AVLMap的完整实现代码,包括所有辅助函数和类定义:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
#include <iostream>
#include <string>
#include <vector>
#include <algorithm> // 用于 std::max
#include <functional> // 用于 std::function

// 模板化的AVL树节点结构
template <typename KeyType, typename ValueType>
struct AVLNode {
KeyType key;
ValueType value;
int height;
AVLNode* left;
AVLNode* right;

AVLNode(const KeyType& k, const ValueType& val)
: key(k), value(val), height(1), left(nullptr), right(nullptr) {}
};

// 获取节点高度
template <typename KeyType, typename ValueType>
int getHeight(AVLNode<KeyType, ValueType>* node) {
if (node == nullptr)
return 0;
return node->height;
}

// 获取平衡因子
template <typename KeyType, typename ValueType>
int getBalance(AVLNode<KeyType, ValueType>* node) {
if (node == nullptr)
return 0;
return getHeight(node->left) - getHeight(node->right);
}

// 右旋转
template <typename KeyType, typename ValueType>
AVLNode<KeyType, ValueType>* rightRotate(AVLNode<KeyType, ValueType>* y) {
AVLNode<KeyType, ValueType>* x = y->left;
AVLNode<KeyType, ValueType>* T2 = x->right;

// 执行旋转
x->right = y;
y->left = T2;

// 更新高度
y->height = std::max(getHeight(y->left), getHeight(y->right)) + 1;
x->height = std::max(getHeight(x->left), getHeight(x->right)) + 1;

// 返回新的根
return x;
}

// 左旋转
template <typename KeyType, typename ValueType>
AVLNode<KeyType, ValueType>* leftRotate(AVLNode<KeyType, ValueType>* x) {
AVLNode<KeyType, ValueType>* y = x->right;
AVLNode<KeyType, ValueType>* T2 = y->left;

// 执行旋转
y->left = x;
x->right = T2;

// 更新高度
x->height = std::max(getHeight(x->left), getHeight(x->right)) + 1;
y->height = std::max(getHeight(y->left), getHeight(y->right)) + 1;

// 返回新的根
return y;
}

// 插入节点
template <typename KeyType, typename ValueType>
AVLNode<KeyType, ValueType>* insertNode(AVLNode<KeyType, ValueType>* node, const KeyType& key, const ValueType& value) {
// 1. 执行标准的BST插入
if (node == nullptr)
return new AVLNode<KeyType, ValueType>(key, value);

if (key < node->key)
node->left = insertNode(node->left, key, value);
else if (key > node->key)
node->right = insertNode(node->right, key, value);
else {
// 如果键已经存在,更新其值
node->value = value;
return node;
}

// 2. 更新节点高度
node->height = 1 + std::max(getHeight(node->left), getHeight(node->right));

// 3. 获取平衡因子
int balance = getBalance(node);

// 4. 根据平衡因子进行旋转

// 左左情况
if (balance > 1 && key < node->left->key)
return rightRotate(node);

// 右右情况
if (balance < -1 && key > node->right->key)
return leftRotate(node);

// 左右情况
if (balance > 1 && key > node->left->key) {
node->left = leftRotate(node->left);
return rightRotate(node);
}

// 右左情况
if (balance < -1 && key < node->right->key) {
node->right = rightRotate(node->right);
return leftRotate(node);
}

return node;
}

// 查找节点
template <typename KeyType, typename ValueType>
ValueType* searchNode(AVLNode<KeyType, ValueType>* node, const KeyType& key) {
if (node == nullptr)
return nullptr;

if (key == node->key)
return &(node->value);
else if (key < node->key)
return searchNode(node->left, key);
else
return searchNode(node->right, key);
}

// 获取最小值节点
template <typename KeyType, typename ValueType>
AVLNode<KeyType, ValueType>* getMinValueNode(AVLNode<KeyType, ValueType>* node) {
AVLNode<KeyType, ValueType>* current = node;
while (current->left != nullptr)
current = current->left;
return current;
}

// 删除节点
template <typename KeyType, typename ValueType>
AVLNode<KeyType, ValueType>* deleteNode(AVLNode<KeyType, ValueType>* root, const KeyType& key) {
// 1. 执行标准的BST删除
if (root == nullptr)
return root;

if (key < root->key)
root->left = deleteNode(root->left, key);
else if (key > root->key)
root->right = deleteNode(root->right, key);
else {
// 节点有一个或没有子节点
if ((root->left == nullptr) || (root->right == nullptr)) {
AVLNode<KeyType, ValueType>* temp = root->left ? root->left : root->right;

// 没有子节点
if (temp == nullptr) {
temp = root;
root = nullptr;
}
else // 一个子节点
*root = *temp; // 复制内容

delete temp;
}
else {
// 节点有两个子节点,获取中序后继
AVLNode<KeyType, ValueType>* temp = getMinValueNode(root->right);
// 复制中序后继的内容到此节点
root->key = temp->key;
root->value = temp->value;
// 删除中序后继
root->right = deleteNode(root->right, temp->key);
}
}

// 如果树只有一个节点
if (root == nullptr)
return root;

// 2. 更新节点高度
root->height = 1 + std::max(getHeight(root->left), getHeight(root->right));

// 3. 获取平衡因子
int balance = getBalance(root);

// 4. 根据平衡因子进行旋转

// 左左情况
if (balance > 1 && getBalance(root->left) >= 0)
return rightRotate(root);

// 左右情况
if (balance > 1 && getBalance(root->left) < 0) {
root->left = leftRotate(root->left);
return rightRotate(root);
}

// 右右情况
if (balance < -1 && getBalance(root->right) <= 0)
return leftRotate(root);

// 右左情况
if (balance < -1 && getBalance(root->right) > 0) {
root->right = rightRotate(root->right);
return leftRotate(root);
}

return root;
}

// 模板化的AVLMap类
template <typename KeyType, typename ValueType>
class AVLMap {
private:
AVLNode<KeyType, ValueType>* root;

// 中序遍历辅助函数
void inorderHelper(AVLNode<KeyType, ValueType>* node, std::vector<std::pair<KeyType, ValueType>>& res) const {
if (node != nullptr) {
inorderHelper(node->left, res);
res.emplace_back(node->key, node->value);
inorderHelper(node->right, res);
}
}

public:
AVLMap() : root(nullptr) {}

// 插入或更新键值对
void put(const KeyType& key, const ValueType& value) {
root = insertNode(root, key, value);
}

// 查找值,返回指向值的指针,如果键不存在则返回nullptr
ValueType* get(const KeyType& key) {
return searchNode(root, key);
}

// 删除键值对
void remove(const KeyType& key) {
root = deleteNode(root, key);
}

// 中序遍历,返回有序的键值对
std::vector<std::pair<KeyType, ValueType>> inorderTraversal() const {
std::vector<std::pair<KeyType, ValueType>> res;
inorderHelper(root, res);
return res;
}

// 析构函数,释放所有节点的内存
~AVLMap() {
// 使用后序遍历释放节点
std::function<void(AVLNode<KeyType, ValueType>*)> destroy = [&](AVLNode<KeyType, ValueType>* node) {
if (node) {
destroy(node->left);
destroy(node->right);
delete node;
}
};
destroy(root);
}
};

// 示例主函数
int main() {
// 示例 1:int 键,std::string 值
std::cout << "示例 1:int 键,std::string 值\n";
AVLMap<int, std::string> avlMap1;

// 插入键值对
avlMap1.put(10, "十");
avlMap1.put(20, "二十");
avlMap1.put(30, "三十");
avlMap1.put(40, "四十");
avlMap1.put(50, "五十");
avlMap1.put(25, "二十五");

// 中序遍历
std::vector<std::pair<int, std::string>> traversal1 = avlMap1.inorderTraversal();
std::cout << "中序遍历: ";
for (const auto& pair : traversal1) {
std::cout << "(" << pair.first << ", \"" << pair.second << "\") ";
}
std::cout << std::endl;

// 查找键
std::string* val1 = avlMap1.get(20);
if (val1)
std::cout << "获取键20的值: " << *val1 << std::endl;
else
std::cout << "键20不存在。" << std::endl;

val1 = avlMap1.get(25);
if (val1)
std::cout << "获取键25的值: " << *val1 << std::endl;
else
std::cout << "键25不存在。" << std::endl;

val1 = avlMap1.get(60);
if (val1)
std::cout << "获取键60的值: " << *val1 << std::endl;
else
std::cout << "键60不存在。" << std::endl;

// 删除键20
avlMap1.remove(20);
std::cout << "删除键20后,中序遍历: ";
traversal1 = avlMap1.inorderTraversal();
for (const auto& pair : traversal1) {
std::cout << "(" << pair.first << ", \"" << pair.second << "\") ";
}
std::cout << std::endl;

std::cout << "\n-----------------------------\n";

// 示例 2:std::string 键,double 值
std::cout << "示例 2:std::string 键,double 值\n";
AVLMap<std::string, double> avlMap2;

// 插入键值对
avlMap2.put("apple", 1.99);
avlMap2.put("banana", 0.99);
avlMap2.put("cherry", 2.99);
avlMap2.put("date", 3.49);
avlMap2.put("elderberry", 5.99);
avlMap2.put("fig", 2.49);

// 中序遍历
std::vector<std::pair<std::string, double>> traversal2 = avlMap2.inorderTraversal();
std::cout << "中序遍历: ";
for (const auto& pair : traversal2) {
std::cout << "(\"" << pair.first << "\", " << pair.second << ") ";
}
std::cout << std::endl;

// 查找键
double* val2 = avlMap2.get("banana");
if (val2)
std::cout << "获取键\"banana\"的值: " << *val2 << std::endl;
else
std::cout << "键\"banana\"不存在。" << std::endl;

val2 = avlMap2.get("fig");
if (val2)
std::cout << "获取键\"fig\"的值: " << *val2 << std::endl;
else
std::cout << "键\"fig\"不存在。" << std::endl;

val2 = avlMap2.get("grape");
if (val2)
std::cout << "获取键\"grape\"的值: " << *val2 << std::endl;
else
std::cout << "键\"grape\"不存在。" << std::endl;

// 删除键"banana"
avlMap2.remove("banana");
std::cout << "删除键\"banana\"后,中序遍历: ";
traversal2 = avlMap2.inorderTraversal();
for (const auto& pair : traversal2) {
std::cout << "(\"" << pair.first << "\", " << pair.second << ") ";
}
std::cout << std::endl;

return 0;
}

说明

  1. 平衡维护:在每次插入或删除后,都会更新节点的高度并计算平衡因子。如果某个节点的平衡因子超出了[-1, 1]范围,就需要通过旋转来重新平衡树。
  2. 查找操作:由于AVL树的高度被保持在O(log n),查找操作的时间复杂度为O(log n)。
  3. 更新操作:如果插入的键已经存在,则更新其对应的值。
  4. 遍历操作:中序遍历可以按键的顺序遍历所有键值对。
  5. 内存管理:确保在析构函数中正确释放所有动态分配的内存,避免内存泄漏。
  6. 泛型支持(可选):为了使AVLMap更加通用,可以将其模板化,以支持不同类型的键和值。例如:

编译与运行:

假设保存为 avlmap.cpp,使用以下命令编译并运行:

1
2
g++ -std=c++11 -o avlmap avlmap.cpp
./avlmap

预期输出:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
示例 1:int 键,std::string 值
中序遍历: (10, "十") (20, "二十") (25, "二十五") (30, "三十") (40, "四十") (50, "五十")
获取键20的值: 二十
获取键25的值: 二十五
键60不存在。
删除键20后,中序遍历: (10, "十") (25, "二十五") (30, "三十") (40, "四十") (50, "五十")

-----------------------------
示例 2:std::string 键,double 值
中序遍历: ("apple", 1.99) ("banana", 0.99) ("cherry", 2.99) ("date", 3.49) ("elderberry", 5.99) ("fig", 2.49)
获取键"banana"的值: 0.99
获取键"fig"的值: 2.49
键"grape"不存在。
删除键"banana"后,中序遍历: ("apple", 1.99) ("cherry", 2.99) ("date", 3.49) ("elderberry", 5.99) ("fig", 2.49)

7. 注意事项与扩展

1. 键类型的要求

为了使AVLMap正常工作,键类型KeyType必须支持以下操作:

  • 比较操作:必须定义operator<,因为AVL树依赖于它来维护排序。如果使用自定义类型作为键,请确保定义了operator<。

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    struct CustomKey {
    int id;
    std::string name;

    bool operator<(const CustomKey& other) const {
    if (id != other.id)
    return id < other.id;
    return name < other.name;
    }
    };

2. 泛型支持与约束

在C++20之前,模板并不支持在模板参数上强制施加约束(需要依赖文档和用户理解)。从C++20起,可以使用概念(Concepts)来施加约束。

1
2
3
4
5
6
7
#include <concepts>

template <typename KeyType, typename ValueType>
requires std::totally_ordered<KeyType>
class AVLMap {
// 类定义
};

这样,编译器会在实例化模板时检查KeyType是否满足std::totally_ordered,即是否支持所有必要的比较操作。

3. 性能优化

  • 内存管理:当前实现使用递归进行插入和删除,如果树非常深,可能会导致栈溢出。可以考虑使用迭代方法或优化递归深度。
  • 缓存友好:使用自适应数据结构(如缓存友好的节点布局)可以提升性能。
  • 多线程支持:当前实现不是线程安全的。如果需要在多线程环境中使用,需要添加适当的同步机制。

4. 额外功能

根据需求,你可以为AVLMap添加更多功能:

  • 迭代器:实现输入迭代器、中序遍历迭代器,以便支持范围基(range-based)for循环。
  • 查找最小/最大键:提供方法findMin()和findMax()。
  • 前驱/后继查找:在树中查找给定键的前驱和后继节点。
  • 支持不同的平衡因子策略:例如,允许用户指定自定义的平衡策略。

5. 与标准库的比较

虽然自己实现AVLMap是一项很好的学习练习,但在实际生产环境中,建议使用C++标准库中已经高度优化和测试过的容器,如std::map(通常实现为红黑树)、std::unordered_map(哈希表)等。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
#include <map>

// 使用 std::map
int main() {
std::map<int, std::string> stdMap;

// 插入键值对
stdMap[10] = "十";
stdMap[20] = "二十";
stdMap[30] = "三十";

// 遍历
for (const auto& pair : stdMap) {
std::cout << "(" << pair.first << ", \"" << pair.second << "\") ";
}
std::cout << std::endl;

return 0;
}

std::map提供了与AVLMap类似的功能,并且经过了高度优化,适用于大多数应用场景。

红黑树

红黑树(Red-Black Tree)是一种自平衡的二叉搜索树,它通过对节点进行颜色标记(红色或黑色)并遵循特定的规则来保证树的平衡性,从而确保基本操作(如查找、插入、删除)的时间复杂度为 **O(log n)**。红黑树广泛应用于计算机科学中,例如在实现关联容器(如 std::map、std::set)时常用到。

1. 红黑树的五大性质

红黑树通过以下 五大性质 维持其平衡性:

  1. 节点颜色:每个节点要么是红色,要么是黑色。
  2. 根节点颜色:根节点是黑色。
  3. 叶子节点颜色:所有叶子节点(NIL 节点,即空节点)都是黑色的。这里的叶子节点不存储实际数据,仅作为树的终端。
  4. 红色节点限制:如果一个节点是红色的,则它的两个子节点都是黑色的。也就是说,红色节点不能有红色的子节点。
  5. 黑色平衡:从任意节点到其所有后代叶子节点的路径上,包含相同数量的黑色节点。这被称为每条路径上的黑色高度相同。

这些性质的意义

  • 性质1 和 性质2 确保节点颜色的基本规则,便于后续操作中进行颜色判断和调整。
  • 性质3 保证了所有实际节点的子节点(NIL 节点)的统一性,简化了操作逻辑。
  • 性质4 防止了连续的红色节点出现,避免导致过度不平衡。
  • 性质5 保证了树的平衡性,使得树的高度始终保持在 O(log n) 的范围内,从而确保基本操作的高效性。

这些性质共同作用,使得红黑树在最坏情况下也能保持较好的性能表现。


2. 红黑树的插入操作

插入操作是红黑树中常见的操作,与标准的二叉搜索树(BST)插入类似,但需要额外的步骤来维护红黑树的性质。

2.1 插入步骤概述

插入操作通常分为以下两个主要步骤:

  1. 标准二叉搜索树插入:根据键值比较,将新节点插入到合适的位置,初始颜色为红色。
  2. 插入后的修正(Insert Fixup):通过重新着色和旋转操作,恢复红黑树的五大性质。

2.2 插入后的修正(Insert Fixup)

插入一个红色节点可能会破坏红黑树的性质,尤其是性质4(红色节点不能连续)。为了修复这些潜在的冲突,需要进行颜色调整和旋转操作。

修正步骤:

插入修正的过程通常遵循以下规则(以下描述假设新插入的节点 z 是红色):

  1. 父节点为黑色:
    • 如果 z 的父节点是黑色的,那么插入操作不会破坏红黑树的性质,修正过程结束。
  2. 父节点为红色:
    • 情况1:z 的叔叔节点(即 z 的父节点的兄弟节点)也是红色。
      • 将父节点和叔叔节点重新着色为黑色。
      • 将祖父节点重新着色为红色。
      • 将 z 指向祖父节点,继续检查上层节点,防止高层的性质被破坏。
    • 情况2:z 的叔叔节点是黑色,且 z 是其父节点的右子节点。
      • 对 z 的父节点进行左旋转。
      • 将 z 指向其新的左子节点(即原父节点)。
    • 情况3:z 的叔叔节点是黑色,且 z 是其父节点的左子节点。
      • 将父节点重新着色为黑色。
      • 将祖父节点重新着色为红色。
      • 对祖父节点进行右旋转。

旋转操作的重要性

在修正过程中,旋转操作用于调整树的局部结构,使得红黑树的性质得以恢复。这些旋转包括左旋转和右旋转,在后续章节中将详细介绍。

插入修正的代码实现示例

以下是红黑树插入修正的一个简化版 C++ 实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
template <typename Key, typename Value>
void RedBlackTree<Key, Value>::insertFixUp(RBTreeNode<Key, Value>* z) {
while (z->parent != nullptr && z->parent->color == RED) {
if (z->parent == z->parent->parent->left) {
RBTreeNode<Key, Value>* y = z->parent->parent->right; // 叔叔节点
if (y != nullptr && y->color == RED) {
// 情况1:叔叔为红色
z->parent->color = BLACK;
y->color = BLACK;
z->parent->parent->color = RED;
z = z->parent->parent;
} else {
if (z == z->parent->right) {
// 情况2:z为右子节点
z = z->parent;
leftRotate(z);
}
// 情况3:z为左子节点
z->parent->color = BLACK;
z->parent->parent->color = RED;
rightRotate(z->parent->parent);
}
} else {
// 情况对称:父节点是右子节点
RBTreeNode<Key, Value>* y = z->parent->parent->left; // 叔叔节点
if (y != nullptr && y->color == RED) {
// 情况1
z->parent->color = BLACK;
y->color = BLACK;
z->parent->parent->color = RED;
z = z->parent->parent;
} else {
if (z == z->parent->left) {
// 情况2
z = z->parent;
rightRotate(z);
}
// 情况3
z->parent->color = BLACK;
z->parent->parent->color = RED;
leftRotate(z->parent->parent);
}
}
}
root->color = BLACK; // 最终根节点必须为黑色
}

3. 红黑树的删除操作

删除操作同样重要且复杂,因为它可能破坏红黑树的多个性质。与插入类似,删除操作也需要两个主要步骤:

  1. 标准二叉搜索树删除:按照 BST 的规则删除节点。
  2. 删除后的修正(Delete Fixup):通过重新着色和旋转操作,恢复红黑树的性质。

3.1 删除步骤概述

删除操作分为以下步骤:

  1. 定位要删除的节点:
    • 如果要删除的节点 z 有两个子节点,则找到其中序后继节点 y(即 z 的右子树中的最小节点)。
    • 将 y 的值复制到 z,然后将删除目标转移到 y。此时 y 至多只有一个子节点。
  2. 删除节点:
    • 若 y 只有一个子节点 x(可能为 NIL),则用 x 替代 y 的位置。
    • 记录被删除节点的原颜色 y_original_color。
  3. 删除修正(仅当 y_original_color 为黑色时):
    • 因为删除一个黑色节点会影响路径上的黑色数量,需通过多次调整来恢复红黑树的性质。

3.2 删除后的修正(Delete Fixup)

删除后的修正较为复杂,涉及多种情况处理。以下为主要的修正步骤和可能遇到的情况:

修正步骤:

  1. 初始化:设 x 为替代被删除的节点的位置,x 可能为实际节点或 NIL 节点。

  2. 循环修正:

    • 当 x 不是根节点,且 x 的颜色为黑色,进入修正循环。
    • 判断 x 是其父节点的左子节点还是右子节点,并相应地设定兄弟节点 w。
  3. 处理不同情况:

    情况1:w 是红色的。

    • 将 w 重新着色为黑色。
    • 将 x 的父节点重新着色为红色。
    • 对 x 的父节点进行左旋转或右旋转,取决于是左子节点还是右子节点。
    • 更新 w,继续修正过程。

    情况2:w 是黑色,且 w 的两个子节点都是黑色。

    • 将 w 重新着色为红色。
    • 将 x 设为其父节点,继续修正。

    情况3:w 是黑色,且 w 的左子节点是红色,右子节点是黑色。

    • 将 w 的左子节点重新着色为黑色。
    • 将 w 重新着色为红色。
    • 对 w 进行右旋转或左旋转,取决于是左子节点还是右子节点。
    • 更新 w,进入情况4。

    情况4:w 是黑色,且 w 的右子节点是红色。

    • 将 w 的颜色设为 x 的父节点颜色。
    • 将 x 的父节点重新着色为黑色。
    • 将 w 的右子节点重新着色为黑色。
    • 对 x 的父节点进行左旋转或右旋转,取决于是左子节点还是右子节点。
    • 结束修正。
  4. 最终步骤:将 x 设为根节点,并将其颜色设为黑色,确保根节点的颜色为黑色。

删除修正的代码实现示例

由于删除修正涉及较多的情况,以下为一个简化版的红黑树删除修正的 C++ 实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
template <typename Key, typename Value>
void RedBlackTree<Key, Value>::deleteFixUp(RBTreeNode<Key, Value>* x) {
while (x != root && (x == nullptr || x->color == BLACK)) {
if (x == x->parent->left) {
RBTreeNode<Key, Value>* w = x->parent->right; // 兄弟节点
if (w->color == RED) {
// 情况1
w->color = BLACK;
x->parent->color = RED;
leftRotate(x->parent);
w = x->parent->right;
}
if ((w->left == nullptr || w->left->color == BLACK) &&
(w->right == nullptr || w->right->color == BLACK)) {
// 情况2
w->color = RED;
x = x->parent;
} else {
if (w->right == nullptr || w->right->color == BLACK) {
// 情况3
if (w->left != nullptr)
w->left->color = BLACK;
w->color = RED;
rightRotate(w);
w = x->parent->right;
}
// 情况4
w->color = x->parent->color;
x->parent->color = BLACK;
if (w->right != nullptr)
w->right->color = BLACK;
leftRotate(x->parent);
x = root; // 修正完成
}
} else {
// 情况对称:x 是右子节点
RBTreeNode<Key, Value>* w = x->parent->left; // 兄弟节点
if (w->color == RED) {
// 情况1
w->color = BLACK;
x->parent->color = RED;
rightRotate(x->parent);
w = x->parent->left;
}
if ((w->right == nullptr || w->right->color == BLACK) &&
(w->left == nullptr || w->left->color == BLACK)) {
// 情况2
w->color = RED;
x = x->parent;
} else {
if (w->left == nullptr || w->left->color == BLACK) {
// 情况3
if (w->right != nullptr)
w->right->color = BLACK;
w->color = RED;
leftRotate(w);
w = x->parent->left;
}
// 情况4
w->color = x->parent->color;
x->parent->color = BLACK;
if (w->left != nullptr)
w->left->color = BLACK;
rightRotate(x->parent);
x = root; // 修正完成
}
}
}
if (x != nullptr)
x->color = BLACK;
}

4. 旋转操作详解

旋转操作是红黑树中用于重新平衡树的关键操作,包括左旋转和右旋转。旋转操作通过调整节点的父子关系,改变树的局部结构,从而保持红黑树的性质。

4.1 左旋转(Left Rotate)

左旋转围绕节点 x 进行,其目的是将 x 的右子节点 y 提升为 x 的父节点,x 变为 y 的左子节点,y 的左子节点 b 成为 x 的右子节点。

旋转前:

1
2
3
4
5
  x
/ \
a y
/ \
b c

旋转后:

1
2
3
4
5
    y
/ \
x c
/ \
a b

左旋转的代码实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
template <typename Key, typename Value>
void RedBlackTree<Key, Value>::leftRotate(RBTreeNode<Key, Value>* x) {
RBTreeNode<Key, Value>* y = x->right;
x->right = y->left;
if (y->left != nullptr)
y->left->parent = x;

y->parent = x->parent;
if (x->parent == nullptr)
root = y;
else if (x == x->parent->left)
x->parent->left = y;
else
x->parent->right = y;

y->left = x;
x->parent = y;
}

4.2 右旋转(Right Rotate)

右旋转是 左旋转 的镜像操作,围绕节点 y 进行,其目的是将 y 的左子节点 x 提升为 y 的父节点,y 变为 x 的右子节点,x 的右子节点 b 成为 y 的左子节点。

旋转前:

1
2
3
4
5
    y
/ \
x c
/ \
a b

旋转后:

1
2
3
4
5
  x
/ \
a y
/ \
b c

右旋转的代码实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
template <typename Key, typename Value>
void RedBlackTree<Key, Value>::rightRotate(RBTreeNode<Key, Value>* y) {
RBTreeNode<Key, Value>* x = y->left;
y->left = x->right;
if (x->right != nullptr)
x->right->parent = y;

x->parent = y->parent;
if (y->parent == nullptr)
root = x;
else if (y == y->parent->right)
y->parent->right = x;
else
y->parent->left = x;

x->right = y;
y->parent = x;
}

旋转操作的作用

通过旋转操作,可以改变树的高度和形状,确保红黑树的性质在插入和删除后得到维护。旋转不会破坏二叉搜索树的性质,仅改变节点之间的指向关系。


5.简化版红黑树实现

节点结构体

首先,我们定义红黑树节点的结构体:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
#include <iostream>

enum Color { RED, BLACK };

template <typename Key, typename Value>
struct RBTreeNode {
Key key;
Value value;
Color color;
RBTreeNode* parent;
RBTreeNode* left;
RBTreeNode* right;

RBTreeNode(Key k, Value v)
: key(k), value(v), color(RED), parent(nullptr), left(nullptr), right(nullptr) {}
};

红黑树类

接下来,我们定义红黑树的主要类,包括插入、删除和遍历功能:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
#include <iostream>

enum Color { RED, BLACK };

template <typename Key, typename Value>
struct RBTreeNode {
Key key;
Value value;
Color color;
RBTreeNode* parent;
RBTreeNode* left;
RBTreeNode* right;

RBTreeNode(Key k, Value v)
: key(k), value(v), color(RED), parent(nullptr), left(nullptr), right(nullptr) {}
};

template <typename Key, typename Value>
class RedBlackTree {
private:
RBTreeNode<Key, Value>* root;

void leftRotate(RBTreeNode<Key, Value>* x) {
RBTreeNode<Key, Value>* y = x->right;
x->right = y->left;
if (y->left != nullptr)
y->left->parent = x;

y->parent = x->parent;
if (x->parent == nullptr)
root = y;
else if (x == x->parent->left)
x->parent->left = y;
else
x->parent->right = y;

y->left = x;
x->parent = y;
}

void rightRotate(RBTreeNode<Key, Value>* y) {
RBTreeNode<Key, Value>* x = y->left;
y->left = x->right;
if (x->right != nullptr)
x->right->parent = y;

x->parent = y->parent;
if (y->parent == nullptr)
root = x;
else if (y == y->parent->right)
y->parent->right = x;
else
y->parent->left = x;

x->right = y;
y->parent = x;
}

void insertFixUp(RBTreeNode<Key, Value>* z) {
while (z->parent != nullptr && z->parent->color == RED) {
if (z->parent == z->parent->parent->left) {
RBTreeNode<Key, Value>* y = z->parent->parent->right; // 叔叔节点
if (y != nullptr && y->color == RED) {
// 情况1
z->parent->color = BLACK;
y->color = BLACK;
z->parent->parent->color = RED;
z = z->parent->parent;
} else {
if (z == z->parent->right) {
// 情况2
z = z->parent;
leftRotate(z);
}
// 情况3
z->parent->color = BLACK;
z->parent->parent->color = RED;
rightRotate(z->parent->parent);
}
} else {
// 父节点是右子节点,情况对称
RBTreeNode<Key, Value>* y = z->parent->parent->left; // 叔叔节点
if (y != nullptr && y->color == RED) {
// 情况1
z->parent->color = BLACK;
y->color = BLACK;
z->parent->parent->color = RED;
z = z->parent->parent;
} else {
if (z == z->parent->left) {
// 情况2
z = z->parent;
rightRotate(z);
}
// 情况3
z->parent->color = BLACK;
z->parent->parent->color = RED;
leftRotate(z->parent->parent);
}
}
}
root->color = BLACK;
}

void inorderHelper(RBTreeNode<Key, Value>* node) const {
if (node == nullptr) return;
inorderHelper(node->left);
std::cout << node->key << " ";
inorderHelper(node->right);
}

public:
RedBlackTree() : root(nullptr) {}

RBTreeNode<Key, Value>* getRoot() const { return root; }

void insert(const Key& key, const Value& value) {
RBTreeNode<Key, Value>* z = new RBTreeNode<Key, Value>(key, value);
RBTreeNode<Key, Value>* y = nullptr;
RBTreeNode<Key, Value>* x = root;

while (x != nullptr) {
y = x;
if (z->key < x->key)
x = x->left;
else
x = x->right;
}

z->parent = y;
if (y == nullptr)
root = z;
else if (z->key < y->key)
y->left = z;
else
y->right = z;

// 插入后修正红黑树性质
insertFixUp(z);
}

void inorderTraversal() const {
inorderHelper(root);
std::cout << std::endl;
}

// 为简化示例,删除操作未实现
// 完整实现需要包含 deleteFixUp 等步骤
};

简要说明

上述红黑树类包含以下主要功能:

  1. **插入操作 (insert)**:
    • 插入新的键值对,并调用 insertFixUp 进行修正,以保持红黑树的性质。
  2. **旋转操作 (leftRotate 和 rightRotate)**:
    • 通过旋转操作重新调整树的结构,确保树的平衡。
  3. **修正插入后的红黑树性质 (insertFixUp)**:
    • 根据红黑树的五大性质,通过重新着色和旋转来修正可能的违规情况。
  4. **中序遍历 (inorderTraversal)**:
    • 以中序遍历的方式输出树中的键,结果应为升序。

注意:为了简化示例,删除操作 (delete) 及其修正 (deleteFixUp) 未在此实现。如果需要完整的删除功能,请参考之前的详细解释或使用标准库中的实现。

6. 红黑树与其他平衡树的比较

红黑树并非唯一的自平衡二叉搜索树,其他常见的平衡树包括 AVL 树(Adelson-Velsky和Landis树)和 Splay 树。以下是红黑树与 AVL 树的比较:

红黑树 vs AVL 树

特性 红黑树 (Red-Black Tree) AVL 树 (AVL Tree)
平衡性 相对不严格,每个路径上的黑色节点相同。 更严格,任意节点的左右子树高度差不超过1。
插入/删除效率 较快,插入和删除操作较少的旋转,适用于频繁修改的场景。 较慢,插入和删除可能需要多次旋转,适用于查找操作多于修改的场景。
查找效率 O(log n) O(log n),常数因子更小,查找速度略快。
实现复杂度 相对简单,旋转操作较少。 实现较复杂,需严格维护高度平衡。
应用场景 操作频繁、需要快速插入和删除的场景。 查找操作频繁、插入和删除相对较少的场景。

选择依据

  • 红黑树更适用于需要频繁插入和删除操作,并且查找操作相对较多的场景,因为其插入和删除操作的调整成本较低。
  • AVL 树适用于查找操作极为频繁,而修改操作相对较少的场景,因为其高度更严格,查找效率更高。

7. 红黑树的应用场景

由于红黑树高效的查找、插入和删除性能,它在计算机科学中的多个领域都有广泛的应用:

  1. 标准库中的关联容器:
    • C++ 标准库中的 std::map 和 std::set 通常基于红黑树实现。
    • Java 的 TreeMap 和 TreeSet 也是基于红黑树。
  2. 操作系统:
    • Linux 内核中的调度器和虚拟内存管理使用红黑树来管理进程和内存资源。
  3. 数据库系统:
    • 一些数据库索引结构使用红黑树来提高查询效率。
  4. 编译器设计:
    • 语法分析树和符号表管理中可能使用红黑树来高效存储和查找符号。

手写双端队列

Posted on 2024-12-27 | In 零基础C++

1. 双端队列 (Deque) 概述

双端队列(Double-Ended Queue,简称 deque)是一种允许在其两端进行高效插入和删除操作的数据结构。与普通队列(只允许在一端插入和另一端删除)相比,双端队列更为灵活。

C++ 标准库中已经提供了 std::deque,但通过自行实现一个双端队列,可以更好地理解其内部机制和迭代器的工作原理。

2. 实现思路

为了实现一个高效的双端队列,我们需要考虑以下几点:

  1. 动态数组:使用动态数组(如环形缓冲区)来存储元素,以便支持在两端进行常数时间的插入和删除。
  2. 头尾指针:维护头部和尾部的索引,以便快速访问两端。
  3. 自动扩展:当容量不足时,自动调整内部缓冲区的大小。
  4. 迭代器支持:定义一个迭代器类,允许用户使用像 begin() 和 end() 这样的函数进行遍历。

接下来,我们将一步步实现这些功能。

3. 详细实现

3.1 内部数据结构

我们将使用一个动态分配的数组作为内部缓冲区,并通过头尾索引来管理队列的前后端。为了支持在两端高效插入和删除,我们将采用环形缓冲区的概念,即当索引达到数组的末端时,自动回绕到数组的开头。

3.2 Deque 类

下面是 Deque 类的基本结构和关键成员:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
#include <iostream>
#include <stdexcept>
#include <iterator>

template <typename T>
class Deque {
private:
T* buffer; // 内部缓冲区
size_t capacity; // 缓冲区容量
size_t front_idx; // 头部索引
size_t back_idx; // 尾部索引
size_t count; // 当前元素数量

// 调整容量
void resize(size_t new_capacity) {
T* new_buffer = new T[new_capacity];
// 重新排列元素
for (size_t i = 0; i < count; ++i) {
new_buffer[i] = buffer[(front_idx + i) % capacity];
}
delete[] buffer;
buffer = new_buffer;
capacity = new_capacity;
front_idx = 0;
back_idx = count;
}

public:
// 构造函数
Deque(size_t initial_capacity = 8)
: capacity(initial_capacity), front_idx(0), back_idx(0), count(0) {
buffer = new T[capacity];
}

// 析构函数
~Deque() {
delete[] buffer;
}

// 复制构造函数和赋值运算符(省略,为简洁起见)

// 检查是否为空
bool empty() const {
return count == 0;
}

// 获取大小
size_t size() const {
return count;
}

// 在前面插入元素
void push_front(const T& value) {
if (count == capacity) {
resize(capacity * 2);
}
front_idx = (front_idx == 0) ? capacity - 1 : front_idx - 1;
buffer[front_idx] = value;
++count;
}

// 在后面插入元素
void push_back(const T& value) {
if (count == capacity) {
resize(capacity * 2);
}
buffer[back_idx] = value;
back_idx = (back_idx + 1) % capacity;
++count;
}

// 从前面删除元素
void pop_front() {
if (empty()) {
throw std::out_of_range("Deque is empty");
}
front_idx = (front_idx + 1) % capacity;
--count;
}

// 从后面删除元素
void pop_back() {
if (empty()) {
throw std::out_of_range("Deque is empty");
}
back_idx = (back_idx == 0) ? capacity - 1 : back_idx - 1;
--count;
}

// 获取前端元素
T& front() {
if (empty()) {
throw std::out_of_range("Deque is empty");
}
return buffer[front_idx];
}

const T& front() const {
if (empty()) {
throw std::out_of_range("Deque is empty");
}
return buffer[front_idx];
}

// 获取后端元素
T& back() {
if (empty()) {
throw std::out_of_range("Deque is empty");
}
size_t last_idx = (back_idx == 0) ? capacity - 1 : back_idx - 1;
return buffer[last_idx];
}

const T& back() const {
if (empty()) {
throw std::out_of_range("Deque is empty");
}
size_t last_idx = (back_idx == 0) ? capacity - 1 : back_idx - 1;
return buffer[last_idx];
}

// 迭代器类将放在这里(见下一部分)

// 迭代器类定义
class Iterator {
private:
Deque<T>* deque_ptr;
size_t index;
size_t pos;

public:
using iterator_category = std::bidirectional_iterator_tag;
using value_type = T;
using difference_type = std::ptrdiff_t;
using pointer = T*;
using reference = T&;

Iterator(Deque<T>* deque, size_t position)
: deque_ptr(deque), pos(position) {}

// 解引用操作
reference operator*() const {
size_t real_idx = (deque_ptr->front_idx + pos) % deque_ptr->capacity;
return deque_ptr->buffer[real_idx];
}

pointer operator->() const {
size_t real_idx = (deque_ptr->front_idx + pos) % deque_ptr->capacity;
return &(deque_ptr->buffer[real_idx]);
}

// 前置递增
Iterator& operator++() {
++pos;
return *this;
}

// 后置递增
Iterator operator++(int) {
Iterator temp = *this;
++pos;
return temp;
}

// 前置递减
Iterator& operator--() {
--pos;
return *this;
}

// 后置递减
Iterator operator--(int) {
Iterator temp = *this;
--pos;
return temp;
}

// 比较操作
bool operator==(const Iterator& other) const {
return (deque_ptr == other.deque_ptr) && (pos == other.pos);
}

bool operator!=(const Iterator& other) const {
return !(*this == other);
}
};

// 获取 begin 迭代器
Iterator begin() {
return Iterator(this, 0);
}

// 获取 end 迭代器
Iterator end() {
return Iterator(this, count);
}
};

3.3 迭代器类

在上面的 Deque 类中,我们定义了一个嵌套的 Iterator 类。这个迭代器支持前向和后向遍历,并且可以与标准的 C++ 迭代器兼容。

关键点解释:

  1. 成员变量:
    • deque_ptr:指向包含此迭代器的 Deque 实例。
    • pos:相对于队列头部的位置。
  2. 重载运算符:
    • operator* 和 operator->:用于访问当前元素。
    • operator++ 和 operator--:前置和后置递增和递减,用于移动迭代器。
    • operator== 和 operator!=:用于比较两个迭代器是否相同。
  3. 注意事项:
    • 迭代器并不管理元素的生命周期,只是提供遍历接口。
    • 迭代器的有效性依赖于队列的修改操作(如插入和删除)。在实际应用中,需要注意迭代器失效的问题。

4. 使用示例

下面是一个使用上述 Deque 类及其迭代器的示例程序:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
#include <iostream>
#include <string>

// 假设 Deque 类已经定义在这里

int main() {
Deque<std::string> dq;

// 在后面插入元素
dq.push_back("Apple");
dq.push_back("Banana");
dq.push_back("Cherry");

// 在前面插入元素
dq.push_front("Date");
dq.push_front("Elderberry");

// 显示队列大小
std::cout << "Deque 大小: " << dq.size() << std::endl;

// 使用迭代器进行遍历
std::cout << "Deque 元素: ";
for (auto it = dq.begin(); it != dq.end(); ++it) {
std::cout << *it << " ";
}
std::cout << std::endl;

// 访问前端和后端元素
std::cout << "前端元素: " << dq.front() << std::endl;
std::cout << "后端元素: " << dq.back() << std::endl;

// 删除元素
dq.pop_front();
dq.pop_back();

// 再次遍历
std::cout << "删除元素后的 Deque: ";
for (auto it = dq.begin(); it != dq.end(); ++it) {
std::cout << *it << " ";
}
std::cout << std::endl;

return 0;
}

预期输出

1
2
3
4
5
Deque 大小: 5
Deque 元素: Elderberry Date Apple Banana Cherry
前端元素: Elderberry
后端元素: Cherry
删除元素后的 Deque: Date Apple Banana

解释:

  1. 插入操作:
    • 使用 push_back 在队列的后端插入 “Apple”, “Banana”, “Cherry”。
    • 使用 push_front 在队列的前端插入 “Date”, “Elderberry”。
    • 最终队列顺序为:Elderberry, Date, Apple, Banana, Cherry
  2. 遍历操作:
    • 使用迭代器从 begin() 到 end() 遍历并打印所有元素。
  3. 访问元素:
    • 使用 front() 获取队列前端的元素。
    • 使用 back() 获取队列后端的元素。
  4. 删除操作:
    • 使用 pop_front 删除前端元素(”Elderberry”)。
    • 使用 pop_back 删除后端元素(”Cherry”)。
    • 删除后,队列顺序为:Date, Apple, Banana

5. 完整代码

以下是完整的 Deque 类及其使用示例代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
#include <iostream>
#include <stdexcept>
#include <iterator>

// Deque 类定义
template <typename T>
class Deque {
private:
T* buffer; // 内部缓冲区
size_t capacity; // 缓冲区容量
size_t front_idx; // 头部索引
size_t back_idx; // 尾部索引
size_t count; // 当前元素数量

// 调整容量
void resize(size_t new_capacity) {
T* new_buffer = new T[new_capacity];
// 重新排列元素
for (size_t i = 0; i < count; ++i) {
new_buffer[i] = buffer[(front_idx + i) % capacity];
}
delete[] buffer;
buffer = new_buffer;
capacity = new_capacity;
front_idx = 0;
back_idx = count;
}

public:
// 构造函数
Deque(size_t initial_capacity = 8)
: capacity(initial_capacity), front_idx(0), back_idx(0), count(0) {
buffer = new T[capacity];
}

// 析构函数
~Deque() {
delete[] buffer;
}

// 禁用复制构造函数和赋值运算符(为了简洁,可根据需要实现)
Deque(const Deque& other) = delete;
Deque& operator=(const Deque& other) = delete;

// 检查是否为空
bool empty() const {
return count == 0;
}

// 获取大小
size_t size() const {
return count;
}

// 在前面插入元素
void push_front(const T& value) {
if (count == capacity) {
resize(capacity * 2);
}
front_idx = (front_idx == 0) ? capacity - 1 : front_idx - 1;
buffer[front_idx] = value;
++count;
}

// 在后面插入元素
void push_back(const T& value) {
if (count == capacity) {
resize(capacity * 2);
}
buffer[back_idx] = value;
back_idx = (back_idx + 1) % capacity;
++count;
}

// 从前面删除元素
void pop_front() {
if (empty()) {
throw std::out_of_range("Deque is empty");
}
front_idx = (front_idx + 1) % capacity;
--count;
}

// 从后面删除元素
void pop_back() {
if (empty()) {
throw std::out_of_range("Deque is empty");
}
back_idx = (back_idx == 0) ? capacity - 1 : back_idx - 1;
--count;
}

// 获取前端元素
T& front() {
if (empty()) {
throw std::out_of_range("Deque is empty");
}
return buffer[front_idx];
}

const T& front() const {
if (empty()) {
throw std::out_of_range("Deque is empty");
}
return buffer[front_idx];
}

// 获取后端元素
T& back() {
if (empty()) {
throw std::out_of_range("Deque is empty");
}
size_t last_idx = (back_idx == 0) ? capacity - 1 : back_idx - 1;
return buffer[last_idx];
}

const T& back() const {
if (empty()) {
throw std::out_of_range("Deque is empty");
}
size_t last_idx = (back_idx == 0) ? capacity - 1 : back_idx - 1;
return buffer[last_idx];
}

// 迭代器类定义
class Iterator {
private:
Deque<T>* deque_ptr;
size_t pos;

public:
using iterator_category = std::bidirectional_iterator_tag;
using value_type = T;
using difference_type = std::ptrdiff_t;
using pointer = T*;
using reference = T&;

Iterator(Deque<T>* deque, size_t position)
: deque_ptr(deque), pos(position) {}

// 解引用操作
reference operator*() const {
size_t real_idx = (deque_ptr->front_idx + pos) % deque_ptr->capacity;
return deque_ptr->buffer[real_idx];
}

pointer operator->() const {
size_t real_idx = (deque_ptr->front_idx + pos) % deque_ptr->capacity;
return &(deque_ptr->buffer[real_idx]);
}

// 前置递增
Iterator& operator++() {
++pos;
return *this;
}

// 后置递增
Iterator operator++(int) {
Iterator temp = *this;
++pos;
return temp;
}

// 前置递减
Iterator& operator--() {
--pos;
return *this;
}

// 后置递减
Iterator operator--(int) {
Iterator temp = *this;
--pos;
return temp;
}

// 比较操作
bool operator==(const Iterator& other) const {
return (deque_ptr == other.deque_ptr) && (pos == other.pos);
}

bool operator!=(const Iterator& other) const {
return !(*this == other);
}
};

// 获取 begin 迭代器
Iterator begin() {
return Iterator(this, 0);
}

// 获取 end 迭代器
Iterator end() {
return Iterator(this, count);
}
};

// 使用示例
int main() {
Deque<std::string> dq;

// 在后面插入元素
dq.push_back("Apple");
dq.push_back("Banana");
dq.push_back("Cherry");

// 在前面插入元素
dq.push_front("Date");
dq.push_front("Elderberry");

// 显示队列大小
std::cout << "Deque 大小: " << dq.size() << std::endl;

// 使用迭代器进行遍历
std::cout << "Deque 元素: ";
for (auto it = dq.begin(); it != dq.end(); ++it) {
std::cout << *it << " ";
}
std::cout << std::endl;

// 访问前端和后端元素
std::cout << "前端元素: " << dq.front() << std::endl;
std::cout << "后端元素: " << dq.back() << std::endl;

// 删除元素
dq.pop_front();
dq.pop_back();

// 再次遍历
std::cout << "删除元素后的 Deque: ";
for (auto it = dq.begin(); it != dq.end(); ++it) {
std::cout << *it << " ";
}
std::cout << std::endl;

return 0;
}

编译和运行

保存上述代码到一个文件,例如 DequeWithIterator.cpp,然后使用 C++ 编译器进行编译和运行:

1
2
g++ -std=c++11 -o DequeWithIterator DequeWithIterator.cpp
./DequeWithIterator

预期输出

1
2
3
4
5
Deque 大小: 5
Deque 元素: Elderberry Date Apple Banana Cherry
前端元素: Elderberry
后端元素: Cherry
删除元素后的 Deque: Date Apple Banana

6. 总结

通过上述步骤,我们成功实现了一个支持双端插入和删除的双端队列(deque),并添加了迭代器支持,使其能够与标准的 C++ 迭代器接口兼容。这个实现包含了以下关键点:

  1. 内部缓冲区管理:
    • 使用动态数组并采用环形缓冲区的方式,支持高效的双端操作。
    • 自动调整缓冲区的容量,确保在元素数量增加时仍能保持高效。
  2. 迭代器实现:
    • 定义了一个嵌套的 Iterator 类,支持前向和后向遍历。
    • 重载了必要的运算符(如 *, ->, ++, --, ==, !=),以实现与标准迭代器的兼容。
  3. 基本操作:
    • push_front 和 push_back:分别在队列的前端和后端插入元素。
    • pop_front 和 pop_back:分别从队列的前端和后端删除元素。
    • front 和 back:访问队列的前端和后端元素。

零基础C++(25) stl几种容器详解

Posted on 2024-12-22 | In 零基础C++

简介

C++的标准模板库(STL)提供了多种通用容器,用于存储和管理数据。这些容器各有特点,适用于不同的应用场景。理解每种容器的用法和内部实现原理,对于编写高效且可维护的代码至关重要。本教案将详细介绍几种常用的STL容器,包括vector、list、deque、map、unordered_map、set、unordered_set以及容器适配器如stack、queue和priority_queue。


vector:动态数组

用法

vector是STL中最常用的序列容器之一,提供了动态大小的数组功能。它支持随机访问,允许在末尾高效地添加和删除元素。

内部实现原理

vector在内部使用动态数组(通常是连续的内存块)来存储元素。当需要扩展容量时,它会分配一块更大的内存,将现有元素复制到新内存中,然后释放旧内存。这种策略在平均情况下保证了push_back的常数时间复杂度。

性能特性

  • 随机访问:支持常数时间的随机访问(O(1))。
  • 末尾插入/删除:push_back和pop_back操作在摊销分析下是常数时间(O(1))。
  • 中间插入/删除:在中间位置插入或删除元素需要移动后续元素,时间复杂度为线性时间(O(n))。

应用场景

  • 需要频繁随机访问元素。
  • 主要在容器末尾进行插入和删除操作。
  • 当容器大小不需要频繁调整(避免频繁的内存重新分配)。

代码示例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
#include <iostream>
#include <vector>

int main() {
// 创建一个空的整数vector
std::vector<int> numbers;

// 向vector末尾添加元素
numbers.push_back(10);
numbers.push_back(20);
numbers.push_back(30);

// 通过索引访问元素
std::cout << "第一个元素: " << numbers[0] << std::endl;

// 遍历vector
std::cout << "所有元素: ";
for(auto it = numbers.begin(); it != numbers.end(); ++it) {
std::cout << *it << " ";
}
std::cout << std::endl;

// 删除最后一个元素
numbers.pop_back();

// 打印删除后的vector
std::cout << "删除最后一个元素后: ";
for(auto num : numbers) {
std::cout << num << " ";
}
std::cout << std::endl;

return 0;
}

输出

1
2
3
第一个元素: 10
所有元素: 10 20 30
删除最后一个元素后: 10 20

list:双向链表

用法

list是一个实现了双向链表的数据结构,适合在容器中间频繁插入和删除元素。与vector不同,list不支持随机访问,但在任何位置的插入和删除操作都是常数时间。

内部实现原理

list在内部使用双向链表,每个元素包含指向前一个和后一个元素的指针。这使得在已知位置插入或删除元素时,无需移动其他元素,只需更新指针即可。

性能特性

  • 随机访问:不支持随机访问,访问第n个元素需要线性时间(O(n))。
  • 中间插入/删除:已知位置的插入和删除操作是常数时间(O(1))。
  • 遍历:顺序遍历,适合需要频繁遍历但不需要随机访问的场景。

应用场景

  • 需要在容器中间频繁插入或删除元素。
  • 不需要进行随机访问。
  • 对内存的局部性要求不高(链表元素在内存中不连续)。

代码示例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
#include <iostream>
#include <list>

int main() {
// 创建一个空的整数list
std::list<int> numbers;

// 向list末尾添加元素
numbers.push_back(100);
numbers.push_back(200);
numbers.push_back(300);

// 向list前端添加元素
numbers.push_front(50);

// 遍历list
std::cout << "所有元素: ";
for(auto it = numbers.begin(); it != numbers.end(); ++it) {
std::cout << *it << " ";
}
std::cout << std::endl;

// 插入元素
auto it = numbers.begin();
++it; // 指向第二个元素
numbers.insert(it, 150);

// 打印插入后的list
std::cout << "插入元素后: ";
for(auto num : numbers) {
std::cout << num << " ";
}
std::cout << std::endl;

// 删除元素
numbers.remove(200);

// 打印删除后的list
std::cout << "删除元素后: ";
for(auto num : numbers) {
std::cout << num << " ";
}
std::cout << std::endl;

return 0;
}

输出

1
2
3
所有元素: 50 100 200 300 
插入元素后: 50 150 100 200 300
删除元素后: 50 150 100 300

模拟实现一个简化版的 List

为了更好地理解 std::list 的内部工作原理,我们可以尝试模拟实现一个简化版的双向链表。下面将逐步介绍如何设计和实现这个 List 类。

类设计

我们的 List 类将包含以下组件:

  1. 节点结构体(Node):表示链表的每个节点。
  2. 迭代器类(Iterator):允许用户遍历链表。
  3. List 类:管理链表的基本操作,如插入、删除和遍历。

节点结构体

每个节点包含数据域和前后指针:

1
2
3
4
5
6
7
8
template<typename T>
struct Node {
T data;
Node* prev;
Node* next;

Node(const T& value = T()) : data(value), prev(nullptr), next(nullptr) {}
};

迭代器实现

为了实现双向迭代器,我们需要定义一个 Iterator 类,支持 ++ 和 -- 操作。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
template<typename T>
class List;

template<typename T>
class Iterator {
public:
using self_type = Iterator<T>;
using value_type = T;
using reference = T&;
using pointer = T*;
using iterator_category = std::bidirectional_iterator_tag;
using difference_type = std::ptrdiff_t;

Iterator(Node<T>* ptr = nullptr) : node_ptr(ptr) {}

// Dereference operator
reference operator*() const { return node_ptr->data; }

// Arrow operator
pointer operator->() const { return &(node_ptr->data); }

// Pre-increment
self_type& operator++() {
if (node_ptr) node_ptr = node_ptr->next;
return *this;
}

// Post-increment
self_type operator++(int) {
self_type temp = *this;
++(*this);
return temp;
}

// Pre-decrement
self_type& operator--() {
if (node_ptr) node_ptr = node_ptr->prev;
return *this;
}

// Post-decrement
self_type operator--(int) {
self_type temp = *this;
--(*this);
return temp;
}

// Equality comparison
bool operator==(const self_type& other) const {
return node_ptr == other.node_ptr;
}

// Inequality comparison
bool operator!=(const self_type& other) const {
return node_ptr != other.node_ptr;
}

private:
Node<T>* node_ptr;

friend class List<T>;
};

List 类

List 类提供链表的基本功能。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
#include <iostream>

template<typename T>
class List {
public:
using iterator = Iterator<T>;
using const_iterator = Iterator<T>;

// 构造函数
List() {
head = new Node<T>(); // 哨兵节点
tail = new Node<T>(); // 哨兵节点
head->next = tail;
tail->prev = head;
}

// 析构函数
~List() {
clear();
delete head;
delete tail;
}

// 禁止拷贝构造和赋值操作(简化实现)
List(const List& other) = delete;
List& operator=(const List& other) = delete;

// 插入元素到迭代器位置之前
iterator insert(iterator pos, const T& value) {
Node<T>* current = pos.node_ptr;
Node<T>* new_node = new Node<T>(value);

Node<T>* prev_node = current->prev;

new_node->next = current;
new_node->prev = prev_node;

prev_node->next = new_node;
current->prev = new_node;

return iterator(new_node);
}

// 删除迭代器指向的元素
iterator erase(iterator pos) {
Node<T>* current = pos.node_ptr;
if (current == head || current == tail) {
// 不能删除哨兵节点
return pos;
}

Node<T>* prev_node = current->prev;
Node<T>* next_node = current->next;

prev_node->next = next_node;
next_node->prev = prev_node;

delete current;

return iterator(next_node);
}

// 在头部插入元素
void push_front(const T& value) {
insert(begin(), value);
}

// 在尾部插入元素
void push_back(const T& value) {
insert(end(), value);
}

// 在头部删除元素
void pop_front() {
if (!empty()) {
erase(begin());
}
}

// 在尾部删除元素
void pop_back() {
if (!empty()) {
iterator temp = end();
--temp;
erase(temp);
}
}

// 获取头元素引用
T& front() {
return head->next->data;
}

// 获取尾元素引用
T& back() {
return tail->prev->data;
}

// 判断是否为空
bool empty() const {
return head->next == tail;
}

// 获取链表大小(O(n)复杂度)
size_t size() const {
size_t count = 0;
for(auto it = begin(); it != end(); ++it) {
++count;
}
return count;
}

// 清空链表
void clear() {
Node<T>* current = head->next;
while(current != tail) {
Node<T>* temp = current;
current = current->next;
delete temp;
}
head->next = tail;
tail->prev = head;
}

// 获取开始迭代器
iterator begin() {
return iterator(head->next);
}

// 获取结束迭代器
iterator end() {
return iterator(tail);
}

// 打印链表(辅助函数)
void print() const {
Node<T>* current = head->next;
while(current != tail) {
std::cout << current->data << " ";
current = current->next;
}
std::cout << std::endl;
}

private:
Node<T>* head; // 头哨兵
Node<T>* tail; // 尾哨兵
};

完整代码示例

下面是一个完整的示例,包括创建 List 对象,进行各种操作,并打印结果。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
#include <iostream>

// 节点结构体
template<typename T>
struct Node {
T data;
Node* prev;
Node* next;

Node(const T& value = T()) : data(value), prev(nullptr), next(nullptr) {}
};

// 迭代器类
template<typename T>
class Iterator {
public:
using self_type = Iterator<T>;
using value_type = T;
using reference = T&;
using pointer = T*;
using iterator_category = std::bidirectional_iterator_tag;
using difference_type = std::ptrdiff_t;

Iterator(Node<T>* ptr = nullptr) : node_ptr(ptr) {}

// Dereference operator
reference operator*() const { return node_ptr->data; }

// Arrow operator
pointer operator->() const { return &(node_ptr->data); }

// Pre-increment
self_type& operator++() {
if (node_ptr) node_ptr = node_ptr->next;
return *this;
}

// Post-increment
self_type operator++(int) {
self_type temp = *this;
++(*this);
return temp;
}

// Pre-decrement
self_type& operator--() {
if (node_ptr) node_ptr = node_ptr->prev;
return *this;
}

// Post-decrement
self_type operator--(int) {
self_type temp = *this;
--(*this);
return temp;
}

// Equality comparison
bool operator==(const self_type& other) const {
return node_ptr == other.node_ptr;
}

// Inequality comparison
bool operator!=(const self_type& other) const {
return node_ptr != other.node_ptr;
}

private:
Node<T>* node_ptr;

friend class List<T>;
};

// List 类
template<typename T>
class List {
public:
using iterator = Iterator<T>;
using const_iterator = Iterator<T>;

// 构造函数
List() {
head = new Node<T>(); // 头哨兵
tail = new Node<T>(); // 尾哨兵
head->next = tail;
tail->prev = head;
}

// 析构函数
~List() {
clear();
delete head;
delete tail;
}

// 禁止拷贝构造和赋值操作(简化实现)
List(const List& other) = delete;
List& operator=(const List& other) = delete;

// 插入元素到迭代器位置之前
iterator insert(iterator pos, const T& value) {
Node<T>* current = pos.node_ptr;
Node<T>* new_node = new Node<T>(value);

Node<T>* prev_node = current->prev;

new_node->next = current;
new_node->prev = prev_node;

prev_node->next = new_node;
current->prev = new_node;

return iterator(new_node);
}

// 删除迭代器指向的元素
iterator erase(iterator pos) {
Node<T>* current = pos.node_ptr;
if (current == head || current == tail) {
// 不能删除哨兵节点
return pos;
}

Node<T>* prev_node = current->prev;
Node<T>* next_node = current->next;

prev_node->next = next_node;
next_node->prev = prev_node;

delete current;

return iterator(next_node);
}

// 在头部插入元素
void push_front(const T& value) {
insert(begin(), value);
}

// 在尾部插入元素
void push_back(const T& value) {
insert(end(), value);
}

// 在头部删除元素
void pop_front() {
if (!empty()) {
erase(begin());
}
}

// 在尾部删除元素
void pop_back() {
if (!empty()) {
iterator temp = end();
--temp;
erase(temp);
}
}

// 获取头元素引用
T& front() {
return head->next->data;
}

// 获取尾元素引用
T& back() {
return tail->prev->data;
}

// 判断是否为空
bool empty() const {
return head->next == tail;
}

// 获取链表大小(O(n)复杂度)
size_t size() const {
size_t count = 0;
for(auto it = begin(); it != end(); ++it) {
++count;
}
return count;
}

// 清空链表
void clear() {
Node<T>* current = head->next;
while(current != tail) {
Node<T>* temp = current;
current = current->next;
delete temp;
}
head->next = tail;
tail->prev = head;
}

// 获取开始迭代器
iterator begin() {
return iterator(head->next);
}

// 获取结束迭代器
iterator end() {
return iterator(tail);
}

// 打印链表(辅助函数)
void print() const {
Node<T>* current = head->next;
while(current != tail) {
std::cout << current->data << " ";
current = current->next;
}
std::cout << std::endl;
}

private:
Node<T>* head; // 头哨兵
Node<T>* tail; // 尾哨兵
};

// 测试代码
int main() {
List<int> lst;

// 插入元素
lst.push_back(10); // 链表: 10
lst.push_front(5); // 链表: 5, 10
lst.push_back(15); // 链表: 5, 10, 15
lst.insert(++lst.begin(), 7); // 链表: 5, 7, 10, 15

// 打印链表
std::cout << "链表内容: ";
lst.print(); // 输出: 5 7 10 15

// 删除元素
lst.pop_front(); // 链表: 7, 10, 15
lst.pop_back(); // 链表: 7, 10

// 打印链表
std::cout << "删除头尾后链表内容: ";
lst.print(); // 输出: 7 10

// 插入和删除
auto it = lst.begin();
lst.insert(it, 3); // 链表: 3, 7, 10
lst.erase(++it); // 链表: 3, 10

// 打印链表
std::cout << "插入和删除后链表内容: ";
lst.print(); // 输出: 3 10

// 清空链表
lst.clear();
std::cout << "清空后,链表是否为空: " << (lst.empty() ? "是" : "否") << std::endl;

return 0;
}

代码解释

  1. **节点结构体 Node**:包含数据域 data,前驱指针 prev 和后继指针 next。

  2. **迭代器类 Iterator**:

    • 构造函数:接受一个 Node<T>* 指针。

    • 重载操作符

      :

      • * 和 -> 用于访问节点数据。
      • ++ 和 -- 支持前向和后向遍历。
      • == 和 != 用于比较迭代器。
  3. List 类:

    • 成员变量

      :

      • head 和 tail 是头尾哨兵节点。
    • 构造函数:初始化头尾哨兵,并将它们互相连接。

    • 析构函数:清空链表并删除哨兵节点。

    • **insert**:在指定位置前插入新节点。

    • **erase**:删除指定位置的节点。

    • **push_front 和 push_back**:分别在头部和尾部插入元素。

    • **pop_front 和 pop_back**:分别删除头部和尾部元素。

    • **front 和 back**:访问头尾元素。

    • **empty 和 size**:检查链表是否为空和获取链表大小。

    • **clear**:清空链表。

    • **begin 和 end**:返回开始和结束迭代器。

    • **print**:辅助函数,用于打印链表内容。

  4. 测试代码:创建 List<int> 对象,并执行一系列的插入、删除和遍历操作,验证 List 类的功能。

编译和运行

保存上述代码到一个名为 List.cpp 的文件中,然后使用以下命令编译和运行:

1
2
g++ -std=c++11 -o List List.cpp
./List

输出结果:

1
2
3
4
链表内容: 5 7 10 15 
删除头尾后链表内容: 7 10
插入和删除后链表内容: 3 10
清空后,链表是否为空: 是

迭代器分类

1. 迭代器(Iterator)简介

在 C++ 中,迭代器 是一种用于遍历容器(如 std::vector、std::list 等)元素的对象。它们提供了类似指针的接口,使得算法可以独立于具体的容器而工作。迭代器的设计允许算法以统一的方式处理不同类型的容器。


2. 迭代器类别(Iterator Categories)

为了使不同类型的迭代器能够支持不同的操作,C++ 标准库将迭代器分为以下几种类别,每种类别支持的操作能力逐级增强:

  1. 输入迭代器(Input Iterator)
  2. 输出迭代器(Output Iterator)
  3. 前向迭代器(Forward Iterator)
  4. 双向迭代器(Bidirectional Iterator)
  5. 随机访问迭代器(Random Access Iterator)
  6. 无效迭代器(Contiguous Iterator)(C++20 引入)

每个类别都继承自前一个类别,具备更强的功能。例如,双向迭代器不仅支持前向迭代器的所有操作,还支持反向迭代(即可以向后移动)。

主要迭代器类别及其特性

类别 支持的操作 示例容器
输入迭代器 只读访问、单向前进 单向链表 std::forward_list
输出迭代器 只写访问、单向前进 输出流 std::ostream_iterator
前向迭代器 读写访问、单向前进 向量 std::vector
双向迭代器 读写访问、单向前进和反向迭代 双向链表 std::list
随机访问迭代器 读写访问、单向前进、反向迭代、跳跃移动(支持算术运算) 向量 std::vector、队列 std::deque
无效迭代器(新) 随机访问迭代器的所有功能,且元素在内存中连续排列 新的 C++ 容器如 std::span

3. iterator_category 的作用

iterator_category 是迭代器类型中的一个别名,用于标识该迭代器所属的类别。它是标准库中 迭代器特性(Iterator Traits) 的一部分,标准算法会根据迭代器的类别优化其行为。

为什么需要 iterator_category

标准库中的算法(如 std::sort、std::find 等)需要知道迭代器支持哪些操作,以便选择最优的实现方式。例如:

  • 对于随机访问迭代器,可以使用快速的随机访问算法(如快速排序)。
  • 对于双向迭代器,只能使用适用于双向迭代的算法(如归并排序)。
  • 对于输入迭代器,只能进行单次遍历,许多复杂算法无法使用。

通过指定 iterator_category,你可以让标准算法了解你自定义迭代器的能力,从而选择合适的方法进行操作。

iterator_category 的声明

在你的自定义迭代器类中,通过以下方式声明迭代器类别:

1
using iterator_category = std::bidirectional_iterator_tag;

这表示该迭代器是一个 双向迭代器,支持向前和向后遍历。


4. std::bidirectional_iterator_tag 详解

std::bidirectional_iterator_tag 是一个标签(Tag),用于标识迭代器类别。C++ 标准库中有多个这样的标签,分别对应不同的迭代器类别:

  • std::input_iterator_tag
  • std::output_iterator_tag
  • std::forward_iterator_tag
  • std::bidirectional_iterator_tag
  • std::random_access_iterator_tag
  • std::contiguous_iterator_tag(C++20)

这些标签本质上是空的结构体,用于类型区分。在标准算法中,通常会通过这些标签进行 重载选择(Overload Resolution) 或 特化(Specialization),以实现针对不同迭代器类别的优化。

继承关系

迭代器标签是有继承关系的:

  • std::forward_iterator_tag 继承自 std::input_iterator_tag
  • std::bidirectional_iterator_tag 继承自 std::forward_iterator_tag
  • std::random_access_iterator_tag 继承自 std::bidirectional_iterator_tag
  • std::contiguous_iterator_tag 继承自 std::random_access_iterator_tag

这种继承关系反映了迭代器类别的能力层级。例如,双向迭代器 具备 前向迭代器 的所有能力,加上反向遍历的能力。


5. 迭代器特性(Iterator Traits)详解

C++ 提供了 迭代器特性(Iterator Traits),通过模板类 std::iterator_traits 来获取迭代器的相关信息。通过这些特性,标准算法可以泛化地处理不同类型的迭代器。

迭代器特性包含的信息

std::iterator_traits 提供以下信息:

  • iterator_category:迭代器类别标签。
  • value_type:迭代器指向的元素类型。
  • difference_type:迭代器间的距离类型(通常是 std::ptrdiff_t)。
  • pointer:指向元素的指针类型。
  • reference:对元素的引用类型。

自定义迭代器与 iterator_traits

当你定义自己的迭代器时,确保提供这些类型别名,以便标准库算法能够正确识别和使用你的迭代器。例如:

1
2
3
4
5
6
7
8
9
10
11
template<typename T>
class Iterator {
public:
using iterator_category = std::bidirectional_iterator_tag;
using value_type = T;
using difference_type = std::ptrdiff_t;
using pointer = T*;
using reference = T&;

// 其他成员函数...
};

这样,使用 std::iterator_traits<Iterator<T>> 时,就能正确获取迭代器的特性。


deque:双端队列

用法

deque(双端队列)是一种支持在两端高效插入和删除元素的序列容器。与vector相比,deque支持在前端和后端均以常数时间进行插入和删除操作。

内部实现原理

deque通常由一系列固定大小的数组块组成,这些块通过一个中央映射数组进行管理。这种结构使得在两端扩展时不需要重新分配整个容器的内存,从而避免了vector在前端插入的高成本。

性能特性

  • 随机访问:支持常数时间的随机访问(O(1))。
  • 前后插入/删除:在前端和后端插入和删除元素的操作都是常数时间(O(1))。
  • 中间插入/删除:在中间位置插入或删除元素需要移动元素,时间复杂度为线性时间(O(n))。

应用场景

  • 需要在容器两端频繁插入和删除元素。
  • 需要随机访问元素。
  • 不需要频繁在中间位置插入和删除元素。

双端队列简介

双端队列(deque)是一种序列容器,允许在其两端高效地插入和删除元素。与vector不同,deque不仅支持在末尾添加或删除元素(如vector),还支持在头部进行同样的操作。此外,deque提供了随机访问能力,可以像vector一样通过索引访问元素。

主要特点

  • 双端操作:支持在头部和尾部高效的插入和删除操作。
  • 随机访问:可以像数组和vector一样通过索引访问元素。
  • 动态大小:可以根据需要增长和收缩,无需预先定义大小。

内存分配策略

deque内部并不使用一个单一的连续内存块,而是将元素分割成多个固定大小的块(也称为缓冲区或页面),并通过一个中央映射数组(通常称为map)来管理这些块。具体来说,deque的内部结构可以分为以下几个部分:

  1. 中央映射数组(Map):
    • 一个指针数组,指向各个数据块。
    • map本身也是动态分配的,可以根据需要增长或收缩。
    • map允许deque在两端添加新的数据块,而无需移动现有的数据块。
  2. 数据块(Blocks):
    • 每个数据块是一个固定大小的连续内存区域,用于存储元素。
    • 数据块的大小通常与编译器和平台相关,但在大多数实现中,数据块的大小在运行时是固定的(如512字节或更多,具体取决于元素类型的大小)。
  3. 起始和结束指针:
    • deque维护指向中央映射数组中第一个有效数据块的指针以及第一个无效数据块的指针。
    • 这些指针帮助deque快速地在两端添加或删除数据块。

https://cdn.llfc.club/912f900f6a609de906df07ee849a57f.png

双端队列的操作实现

插入操作

在末尾插入 (push_back)

  1. 检查当前末端数据块的剩余空间:
    • 如果有空间,直接在当前末端数据块中插入新元素。
    • 如果没有空间,分配一个新的数据块,并将其指针添加到map中,然后在新块中插入元素。
  2. 更新末尾指针:
    • 如果分配了新块,末尾指针指向该块的第一个元素。
    • 否则,末尾指针移动到当前末端数据块的下一个位置。

在前端插入 (push_front)

  1. 检查当前前端数据块的剩余空间:
    • 如果有空间,直接在当前前端数据块中插入新元素。
    • 如果没有空间,分配一个新的数据块,并将其指针添加到map的前面,然后在新块中插入元素。
  2. 更新前端指针:
    • 如果分配了新块,前端指针指向该块的最后一个元素。
    • 否则,前端指针移动到当前前端数据块的前一个位置。

删除操作

从末尾删除 (pop_back)

  1. 检查末端数据块是否有元素

    :

    • 如果有,移除最后一个元素,并更新末尾指针。
    • 如果数据块变为空,释放该数据块并从map中移除其指针,然后更新末尾指针指向前一个块。

从前端删除 (pop_front)

  1. 检查前端数据块是否有元素

    :

    • 如果有,移除第一个元素,并更新前端指针。
    • 如果数据块变为空,释放该数据块并从map中移除其指针,然后更新前端指针指向下一个块。

访问操作

随机访问

deque支持通过索引进行随机访问,其内部机制如下:

  1. 计算元素的位置:
    • 根据给定的索引,确定对应的数据块和数据块内的偏移量。
    • 使用map数组定位到具体的块,然后通过偏移量定位到块内的元素。
  2. 访问元素:
    • 一旦定位到具体的位置,即可像数组一样访问元素。

迭代器访问

deque提供双向迭代器,支持使用标准的C++迭代器操作(如++、--等)进行遍历。


双端队列的性能特性

理解deque的内部实现有助于理解其性能特性。以下是deque的主要操作及其时间复杂度:

操作 时间复杂度 说明
随机访问(通过索引) 常数时间(O(1)) 通过计算块和偏移量直接访问元素
插入/删除前端 常数时间(O(1)) 仅涉及前端指针和可能的数据块分配
插入/删除末端 常数时间(O(1)) 仅涉及末端指针和可能的数据块分配
中间插入/删除 线性时间(O(n)) 需要移动数据块内的元素,可能涉及多个块的操作
查找元素 线性时间(O(n)) 需要遍历元素进行查找
插入单个元素 平均常数时间(O(1)) 在前端或末端插入,通常不需移动大量元素
插入大量元素 线性时间(O(n)) 需要分配新的数据块并进行元素复制

优缺点

优点

  • 双端操作高效:在两端插入和删除元素非常快速,不需要移动大量元素。
  • 支持随机访问:可以像vector一样通过索引高效访问元素。
  • 动态增长:无需预先定义大小,可以根据需要自动调整。

缺点

  • 内存碎片:由于使用多个数据块,可能导致内存碎片,尤其是在大量插入和删除操作后。
  • 较低的局部性:元素不连续存储,可能导致缓存未命中率较高,影响性能。
  • 复杂性较高:内部实现相对复杂,不如vector直接高效。

双端队列与其他容器的比较

特性 vector deque list
内存结构 单一连续内存块 多块连续内存,通过映射数组管理 双向链表
随机访问 是,常数时间(O(1)) 是,常数时间(O(1)) 否,需要线性时间(O(n))
前端插入/删除 低效,线性时间(O(n)) 高效,常数时间(O(1)) 高效,常数时间(O(1))
末端插入/删除 高效,常数时间(O(1)) 高效,常数时间(O(1)) 高效,常数时间(O(1))
内存碎片 低,由于单一连续内存块 较高,由于多块内存管理 较高,由于节点分散在内存中
元素隔离 高,局部性较好 中等,分块存储提高了部分局部性 低,元素分散存储,缓存效率低
应用场景 需要频繁随机访问、末端操作的场景 需要频繁在两端插入/删除且偶尔随机访问的场景 需要频繁在中间插入/删除且不需要随机访问的场景

代码示例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
#include <iostream>
#include <deque>

int main() {
// 创建一个空的deque
std::deque<std::string> dq;

// 在末尾添加元素
dq.push_back("End1");
dq.push_back("End2");

// 在前端添加元素
dq.push_front("Front1");
dq.push_front("Front2");

// 遍历deque
std::cout << "deque中的元素: ";
for(auto it = dq.begin(); it != dq.end(); ++it) {
std::cout << *it << " ";
}
std::cout << std::endl;

// 访问首尾元素
std::cout << "首元素: " << dq.front() << std::endl;
std::cout << "尾元素: " << dq.back() << std::endl;

// 删除首元素
dq.pop_front();

// 删除尾元素
dq.pop_back();

// 打印删除后的deque
std::cout << "删除首尾元素后: ";
for(auto num : dq) {
std::cout << num << " ";
}
std::cout << std::endl;

return 0;
}

输出

1
2
3
4
deque中的元素: Front2 Front1 End1 End2 
首元素: Front2
尾元素: End2
删除首尾元素后: Front1 End1

map和unordered_map:关联数组

map用法与原理

用法

map是一个关联容器,用于存储键值对(key-value)。它基于键自动排序,且每个键都是唯一的。map提供了快速的查找、插入和删除操作。

内部实现原理

map通常使用自平衡的二叉搜索树(如红黑树)实现。这确保了所有操作的时间复杂度为对数时间(O(log n)),且元素按照键的顺序排列。

unordered_map用法与原理

用法

unordered_map也是一种关联容器,用于存储键值对,但它不保证元素的顺序。unordered_map基于哈希表实现,提供了平均常数时间(O(1))的查找、插入和删除操作。

内部实现原理

unordered_map使用哈希表来存储元素。键通过哈希函数转换为哈希值,并映射到特定的桶中。如果多个键映射到同一桶,会通过链表或其他方法解决冲突。

性能对比

操作 map unordered_map
查找 O(log n) 平均 O(1)
插入 O(log n) 平均 O(1)
删除 O(log n) 平均 O(1)
内存使用 较高 较低
元素顺序 有序 无序

应用场景

  • **map**:
    • 需要按键的顺序遍历元素。
    • 需要有序的关联数组。
    • 需要高效的范围查找。
  • **unordered_map**:
    • 对元素顺序没有要求。
    • 需要极高效的查找、插入和删除操作。
    • 不需要自定义的排序规则。

代码示例

map示例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
#include <iostream>
#include <map>
#include <string>

int main() {
// 创建一个空的map,键为string,值为int
std::map<std::string, int> ageMap;

// 插入键值对
ageMap["Alice"] = 30;
ageMap["Bob"] = 25;
ageMap["Charlie"] = 35;

// 查找元素
std::string name = "Bob";
if(ageMap.find(name) != ageMap.end()) {
std::cout << name << " 的年龄是 " << ageMap[name] << std::endl;
} else {
std::cout << "未找到 " << name << std::endl;
}

// 遍历map
std::cout << "所有人员和年龄: " << std::endl;
for(auto it = ageMap.begin(); it != ageMap.end(); ++it) {
std::cout << it->first << " : " << it->second << std::endl;
}

// 删除元素
ageMap.erase("Alice");

// 打印删除后的map
std::cout << "删除Alice后: " << std::endl;
for(auto &[key, value] : ageMap) {
std::cout << key << " : " << value << std::endl;
}

return 0;
}

unordered_map示例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
#include <iostream>
#include <unordered_map>
#include <string>

int main() {
// 创建一个空的unordered_map,键为string,值为double
std::unordered_map<std::string, double> priceMap;

// 插入键值对
priceMap["Apple"] = 1.2;
priceMap["Banana"] = 0.5;
priceMap["Orange"] = 0.8;

// 查找元素
std::string fruit = "Banana";
if(priceMap.find(fruit) != priceMap.end()) {
std::cout << fruit << " 的价格是 $" << priceMap[fruit] << std::endl;
} else {
std::cout << "未找到 " << fruit << std::endl;
}

// 遍历unordered_map
std::cout << "所有水果和价格: " << std::endl;
for(auto &[key, value] : priceMap) {
std::cout << key << " : $" << value << std::endl;
}

// 删除元素
priceMap.erase("Apple");

// 打印删除后的unordered_map
std::cout << "删除Apple后: " << std::endl;
for(auto &[key, value] : priceMap) {
std::cout << key << " : $" << value << std::endl;
}

return 0;
}

输出

map输出

1
2
3
4
5
6
7
8
Bob 的年龄是 25
所有人员和年龄:
Alice : 30
Bob : 25
Charlie : 35
删除Alice后:
Bob : 25
Charlie : 35

unordered_map输出

1
2
3
4
5
6
7
8
Banana 的价格是 $0.5
所有水果和价格:
Apple : $1.2
Banana : $0.5
Orange : $0.8
删除Apple后:
Banana : $0.5
Orange : $0.8

set和unordered_set:集合

set用法与原理

用法

set是一个关联容器,用于存储唯一的、有序的元素。set基于键自动排序,且每个元素都是唯一的。

内部实现原理

set通常使用自平衡的二叉搜索树(如红黑树)实现,保证元素按顺序排列。每次插入元素时,都会自动保持树的平衡,并确保元素的唯一性。

unordered_set用法与原理

用法

unordered_set也是一种集合容器,用于存储唯一的元素,但它不保证元素的顺序。unordered_set基于哈希表实现,提供了平均常数时间(O(1))的查找、插入和删除操作。

内部实现原理

unordered_set使用哈希表存储元素。每个元素通过哈希函数转换为哈希值,并映射到特定的桶中。冲突通过链表或其他方法解决。

性能对比

操作 set unordered_set
查找 O(log n) 平均 O(1)
插入 O(log n) 平均 O(1)
删除 O(log n) 平均 O(1)
内存使用 较高 较低
元素顺序 有序 无序

应用场景

  • **set**:
    • 需要有序的唯一元素集合。
    • 需要按顺序遍历元素。
    • 需要基于区间的操作(如查找、删除某范围的元素)。
  • **unordered_set**:
    • 对元素顺序无要求。
    • 需要极高效的查找、插入和删除操作。
    • 不需要自定义的排序规则。

代码示例

set示例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
#include <iostream>
#include <set>

int main() {
// 创建一个空的整数set
std::set<int> numbers;

// 插入元素
numbers.insert(10);
numbers.insert(20);
numbers.insert(30);
numbers.insert(20); // 重复元素,不会被插入

// 遍历set
std::cout << "set中的元素: ";
for(auto it = numbers.begin(); it != numbers.end(); ++it) {
std::cout << *it << " ";
}
std::cout << std::endl;

// 查找元素
int key = 20;
if(numbers.find(key) != numbers.end()) {
std::cout << key << " 在set中存在。" << std::endl;
} else {
std::cout << key << " 不在set中。" << std::endl;
}

// 删除元素
numbers.erase(10);

// 打印删除后的set
std::cout << "删除10后set中的元素: ";
for(auto num : numbers) {
std::cout << num << " ";
}
std::cout << std::endl;

return 0;
}

unordered_set示例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
#include <iostream>
#include <unordered_set>

int main() {
// 创建一个空的unordered_set
std::unordered_set<int> numbers;

// 插入元素
numbers.insert(10);
numbers.insert(20);
numbers.insert(30);
numbers.insert(20); // 重复元素,不会被插入

// 遍历unordered_set
std::cout << "unordered_set中的元素: ";
for(auto it = numbers.begin(); it != numbers.end(); ++it) {
std::cout << *it << " ";
}
std::cout << std::endl;

// 查找元素
int key = 20;
if(numbers.find(key) != numbers.end()) {
std::cout << key << " 在unordered_set中存在。" << std::endl;
} else {
std::cout << key << " 不在unordered_set中。" << std::endl;
}

// 删除元素
numbers.erase(10);

// 打印删除后的unordered_set
std::cout << "删除10后unordered_set中的元素: ";
for(auto num : numbers) {
std::cout << num << " ";
}
std::cout << std::endl;

return 0;
}

输出

set输出

1
2
3
set中的元素: 10 20 30 
20 在set中存在。
删除10后set中的元素: 20 30

unordered_set输出(注意元素顺序可能不同)

1
2
3
unordered_set中的元素: 10 20 30 
20 在unordered_set中存在。
删除10后unordered_set中的元素: 20 30

stack、queue和priority_queue:容器适配器

用法

STL中的容器适配器(stack、queue、priority_queue)提供了特定的数据结构接口,这些适配器在内部使用其他容器来存储元素(默认使用deque或vector)。

内部实现原理

  • **stack**:后进先出(LIFO)数据结构,通常使用deque或vector作为底层容器,通过限制操作接口来实现。
  • **queue**:先进先出(FIFO)数据结构,通常使用deque作为底层容器,通过限制操作接口来实现。
  • **priority_queue**:基于堆的数据结构,通常使用vector作为底层容器,并通过堆算法(如std::make_heap、std::push_heap、std::pop_heap)维护元素的优先级顺序。

性能特性

  • **stack**:
    • 访问顶部元素:O(1)
    • 插入和删除:O(1)
  • **queue**:
    • 访问前端和后端元素:O(1)
    • 插入和删除:O(1)
  • **priority_queue**:
    • 访问顶部(最大或最小元素):O(1)
    • 插入和删除:O(log n)

应用场景

  • **stack**:
    • 实现函数调用栈。
    • 处理撤销操作。
    • 深度优先搜索(DFS)。
  • **queue**:
    • 实现任务调度。
    • 广度优先搜索(BFS)。
    • 数据流处理。
  • **priority_queue**:
    • 实现优先级调度。
    • 求解最短路径算法(如Dijkstra)。
    • 任意需要按优先级处理元素的场景。

代码示例

stack示例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
#include <iostream>
#include <stack>

int main() {
// 创建一个空的stack,底层使用vector
std::stack<int> s;

// 压入元素
s.push(1);
s.push(2);
s.push(3);

// 访问栈顶元素
std::cout << "栈顶元素: " << s.top() << std::endl;

// 弹出元素
s.pop();
std::cout << "弹出一个元素后,新的栈顶: " << s.top() << std::endl;

// 判断栈是否为空
if(!s.empty()) {
std::cout << "栈不为空,元素数量: " << s.size() << std::endl;
}

return 0;
}

queue示例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
#include <iostream>
#include <queue>

int main() {
// 创建一个空的queue,底层使用deque
std::queue<std::string> q;

// 入队元素
q.push("First");
q.push("Second");
q.push("Third");

// 访问队首元素
std::cout << "队首元素: " << q.front() << std::endl;

// 访问队尾元素
std::cout << "队尾元素: " << q.back() << std::endl;

// 出队元素
q.pop();
std::cout << "出队后新的队首: " << q.front() << std::endl;

// 判断队列是否为空
if(!q.empty()) {
std::cout << "队列不为空,元素数量: " << q.size() << std::endl;
}

return 0;
}

priority_queue示例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
#include <iostream>
#include <queue>
#include <vector>

int main() {
// 创建一个空的priority_queue,默认是最大堆
std::priority_queue<int> pq;

// 插入元素
pq.push(30);
pq.push(10);
pq.push(20);
pq.push(40);

// 访问堆顶元素
std::cout << "优先级最高的元素: " << pq.top() << std::endl;

// 弹出元素
pq.pop();
std::cout << "弹出一个元素后,新的堆顶: " << pq.top() << std::endl;

// 遍历priority_queue(需要复制,因为无法直接遍历)
std::priority_queue<int> copy = pq;
std::cout << "剩余的元素: ";
while(!copy.empty()) {
std::cout << copy.top() << " ";
copy.pop();
}
std::cout << std::endl;

return 0;
}

输出

stack输出

1
2
3
栈顶元素: 3
弹出一个元素后,新的栈顶: 2
栈不为空,元素数量: 2

queue输出

1
2
3
4
队首元素: First
队尾元素: Third
出队后新的队首: Second
队列不为空,元素数量: 2

priority_queue输出

1
2
3
优先级最高的元素: 40
弹出一个元素后,新的堆顶: 30
剩余的元素: 30 20 10

总结

C++ STL提供了丰富多样的容器,适用于各种不同的数据存储和管理需求。理解每种容器的特点、内部实现原理以及性能特性,可以帮助开发者在实际应用中做出最佳的选择,从而编写出高效且可维护的代码。

  • 序列容器:
    • **vector**:适用于需要频繁随机访问和在末尾操作的场景。
    • **list**:适用于需要在中间频繁插入和删除的场景。
    • **deque**:适用于需要在两端频繁插入和删除的场景。
  • 关联容器:
    • **map*和*set**:适用于需要有序存储和快速查找的场景。
    • **unordered_map*和*unordered_set**:适用于需要高效查找且对元素顺序无要求的场景。
  • 容器适配器:
    • **stack**:用于LIFO操作。
    • **queue**:用于FIFO操作。
    • **priority_queue**:用于优先级队列操作。
12…37>

370 posts
17 categories
21 tags
RSS
GitHub ZhiHu
© 2025 恋恋风辰 本站总访问量次 | 本站访客数人
Powered by Hexo
|
Theme — NexT.Muse v5.1.3