1.6. Utility Functions and Classes

This section contains the implementations of utility functions and classes used in this book.

import collections
import hashlib
import inspect
import math
import os
import random
import tarfile
import zipfile

import mlx
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import mlx.utils
import numpy as np
from IPython import display

from d2l import mlx as d2l
class HyperParameters: #@save
    """The base class of hyperparameters.

    Subclasses call :meth:`save_hyperparameters` from their ``__init__`` to
    turn every constructor argument into an attribute of the same name.
    """
    def save_hyperparameters(self, ignore=None):
        """Save the calling function's arguments into class attributes.

        Reads the local variables of the *caller's* frame (normally
        ``__init__``), drops ``self``, every name in ``ignore`` and every
        name starting with ``_``, stores the rest in ``self.hparams`` and
        sets each of them as an attribute on ``self``.

        Args:
            ignore: Optional iterable (list, tuple, set, ...) of argument
                names that must not be saved.

        Defined in :numref:`sec_utils`"""
        excluded = set(ignore) if ignore is not None else set()
        excluded.add('self')
        frame = inspect.currentframe().f_back
        try:
            _, _, _, local_vars = inspect.getargvalues(frame)
        finally:
            # Drop the frame reference promptly to avoid a reference cycle.
            del frame
        self.hparams = {k: v for k, v in local_vars.items()
                        if k not in excluded and not k.startswith('_')}
        for k, v in self.hparams.items():
            setattr(self, k, v)

class DataModule(d2l.HyperParameters): #@save
    """The base class of data.

    Subclasses implement :meth:`get_dataloader`; the train/validation
    accessors just dispatch to it with the matching ``train`` flag.

    Defined in :numref:`subsec_oo-design-models`"""
    def __init__(self, root='../data'):
        self.save_hyperparameters()

    def get_dataloader(self, train):
        raise NotImplementedError

    def train_dataloader(self):
        return self.get_dataloader(train=True)

    def val_dataloader(self):
        return self.get_dataloader(train=False)

    def get_tensorloader(self, tensors, train, indices=slice(0, None)):
        """Wrap the ``indices`` slice of every tensor in a ``d2l.DataLoader``.

        The loader uses ``self.batch_size`` (set by the subclass) and
        shuffles only when ``train`` is true.

        Defined in :numref:`sec_synthetic-regression-data`"""
        selected = [t[indices] for t in tensors]
        return d2l.DataLoader(d2l.Dataset(*selected), self.batch_size,
                              shuffle=train)

class ProgressBoard(d2l.HyperParameters): #@save
    """The board that plots data points in animation.

    Raw points are buffered per ``label`` and averaged every ``every_n``
    calls. Each averaged point is appended to that label's line. When
    ``display`` is true, the whole figure is redrawn.

    Defined in :numref:`sec_oo-design`"""
    # Created once for the class instead of on every ``draw`` call.
    _Point = collections.namedtuple('Point', ['x', 'y'])

    def __init__(self, xlabel=None, ylabel=None, xlim=None,
                 ylim=None, xscale='linear', yscale='linear',
                 ls=['-', '--', '-.', ':'], colors=['C0', 'C1', 'C2', 'C3'],
                 fig=None, axes=None, figsize=(3.5, 2.5), display=True):
        # ``ls``/``colors`` are only read, never mutated, so sharing the
        # list defaults between instances is harmless.
        self.save_hyperparameters()

    def draw(self, x, y, label, every_n=1):
        """Add the raw point ``(x, y)`` to the line named ``label``.

        Args:
            x, y: Coordinates of the raw point.
            label: Legend name of the line the point belongs to.
            every_n: Number of raw points averaged into one plotted point.

        Defined in :numref:`sec_utils`"""
        Point = self._Point
        if not hasattr(self, 'raw_points'):
            self.raw_points = collections.OrderedDict()
            self.data = collections.OrderedDict()
        if label not in self.raw_points:
            self.raw_points[label] = []
            self.data[label] = []
        points = self.raw_points[label]
        line = self.data[label]
        points.append(Point(x, y))
        if len(points) != every_n:
            return
        mean = lambda values: sum(values) / len(values)
        line.append(Point(mean([p.x for p in points]),
                          mean([p.y for p in points])))
        points.clear()
        if not self.display:
            return
        d2l.use_svg_display()
        if self.fig is None:
            self.fig = d2l.plt.figure(figsize=self.figsize)
        plt_lines, labels = [], []
        # zip() stops at the shortest input: at most len(self.ls) /
        # len(self.colors) lines are drawn.
        for (k, v), ls, color in zip(self.data.items(), self.ls, self.colors):
            plt_lines.append(d2l.plt.plot([p.x for p in v], [p.y for p in v],
                                          linestyle=ls, color=color)[0])
            labels.append(k)
        axes = self.axes if self.axes else d2l.plt.gca()
        if self.xlim: axes.set_xlim(self.xlim)
        if self.ylim: axes.set_ylim(self.ylim)
        # Fall back to a generic axis label when none was configured
        # (there is no ``self.x`` attribute to fall back to).
        if not self.xlabel: self.xlabel = 'x'
        axes.set_xlabel(self.xlabel)
        axes.set_ylabel(self.ylabel)
        axes.set_xscale(self.xscale)
        axes.set_yscale(self.yscale)
        axes.legend(plt_lines, labels)
        display.display(self.fig)
        display.clear_output(wait=True)

class Module(nn.Module, d2l.HyperParameters): #@save
    """The base class of models.

    Subclasses provide ``self.net`` (or override ``__call__``), ``loss`` and
    usually ``self.lr`` for :meth:`configure_optimizers`.

    Defined in :numref:`sec_oo-design`"""
    def __init__(self, plot_train_per_epoch=2, plot_valid_per_epoch=1):
        super().__init__()
        self.save_hyperparameters()
        self.board = ProgressBoard()

    def loss(self, y_hat, y):
        raise NotImplementedError

    def __call__(self, X):
        assert hasattr(self, 'net'), 'Neural network is not defined'
        return self.net(X)

    def plot(self, key, value, train):
        """Plot a point in animation.

        The x axis is measured in epochs. Training values are averaged so
        that about ``plot_train_per_epoch`` points appear per epoch. The
        same holds for validation values with ``plot_valid_per_epoch``.
        """
        assert hasattr(self, 'trainer'), 'Trainer is not inited'
        self.board.xlabel = 'epoch'
        if train:
            x = self.trainer.train_batch_idx / \
                self.trainer.num_train_batches
            n = self.trainer.num_train_batches / \
                self.plot_train_per_epoch
        else:
            x = self.trainer.epoch + 1
            n = self.trainer.num_val_batches / \
                self.plot_valid_per_epoch
        # With fewer batches than requested plot points int(n) would be 0,
        # and ProgressBoard.draw would then never emit a point.
        self.board.draw(x, np.array(value),
                        ('train_' if train else 'val_') + key,
                        every_n=max(1, int(n)))

    def training_step(self, batch):
        """Compute the loss and its gradients for one batch and plot the loss.

        ``batch[:-1]`` are the model inputs and ``batch[-1]`` the targets.

        Returns:
            ``(loss, grads)`` as produced by ``nn.value_and_grad``.
        """
        def loss_fn(X, y):
            return self.loss(self(*X), y)
        l, grad = nn.value_and_grad(self, loss_fn)(batch[:-1], batch[-1])
        self.plot('loss', l.item(), train=True)
        return l, grad

    def validation_step(self, batch):
        l = self.loss(self(*batch[:-1]), batch[-1])
        self.plot('loss', l.item(), train=False)

    def configure_optimizers(self):
        """Return plain SGD with learning rate ``self.lr`` (set by subclasses).

        Defined in :numref:`sec_classification`"""
        return optim.SGD(learning_rate=self.lr)


class RNNScratch(d2l.Module): #@save
    """The RNN model implemented from scratch.

    Defined in :numref:`sec_rnn-scratch`"""
    def __init__(self, num_inputs, num_hiddens, sigma=0.01):
        super().__init__()
        self.save_hyperparameters()
        self.W_xh = mx.random.normal(shape=(num_inputs, num_hiddens)) * sigma
        self.W_hh = mx.random.normal(shape=(num_hiddens, num_hiddens)) * sigma
        self.b_h = mx.zeros(num_hiddens)

    def forward(self, inputs, state=None):
        """Run the tanh recurrence over every step of ``inputs``.

        Args:
            inputs: Iterated along its first axis (presumably shaped
                ``(num_steps, batch_size, num_inputs)``).
            state: ``None`` to start from zeros of shape
                ``(inputs.shape[1], num_hiddens)``, otherwise a one-element
                sequence ``(H,)``.

        Returns:
            ``(outputs, state)``: the list of per-step hidden states and the
            final hidden state.

        Defined in :numref:`sec_rnn-scratch`"""
        if state is None:
            # Initial state with shape: (batch_size, num_hiddens)
            state = mx.zeros((inputs.shape[1], self.num_hiddens))
        else:
            state, = state
        outputs = []
        for X in inputs:  # Shape of inputs: (num_steps, batch_size, num_inputs)
            state = mx.tanh(mx.matmul(X, self.W_xh) +
                             mx.matmul(state, self.W_hh) + self.b_h)
            outputs.append(state)
        return outputs, state

    def __call__(self, inputs, state=None):
        """Delegate to :meth:`forward` so the recurrence lives in one place.

        Defined in :numref:`sec_rnn-scratch`"""
        return self.forward(inputs, state)

def check_len(a, n): #@save
    """Assert that ``a`` contains exactly ``n`` elements.

    Raises:
        AssertionError: reporting the actual and expected lengths.

    Defined in :numref:`sec_rnn-scratch`"""
    actual = len(a)
    assert actual == n, f"list's length {actual} != expected length {n}"

def check_shape(a, shape): #@save
    """Check the shape of a tensor.

    Both ``a.shape`` and ``shape`` are compared as tuples, so the expected
    shape may be given as a list or a tuple.

    Raises:
        AssertionError: if the shapes differ.

    Defined in :numref:`sec_rnn-scratch`"""
    assert tuple(a.shape) == tuple(shape), \
            f'tensor\'s shape {a.shape} != expected shape {shape}'


class RNN(d2l.Module): #@save
    """The RNN model implemented with high-level APIs.

    In this port the wrapped recurrent layer is ``d2l.RNNScratch``; calling
    the model forwards the inputs and optional state to it unchanged.

    Defined in :numref:`sec_rnn-concise`"""
    def __init__(self, num_inputs, num_hiddens):
        super().__init__()
        self.save_hyperparameters()
        self.rnn = d2l.RNNScratch(num_inputs, num_hiddens)

    def __call__(self, inputs, H=None):
        recurrent_layer = self.rnn
        return recurrent_layer(inputs, H)

class GRUScratch(d2l.Module): #@save
    """A gated recurrent unit (GRU) layer implemented from scratch.

    Defined in :numref:`sec_gru`"""
    def __init__(self, num_inputs, num_hiddens, sigma=0.01, dropout=0):
        super().__init__()
        self.save_hyperparameters()
        init_weight = lambda *shape: mx.random.normal(*shape) * sigma
        triple = lambda: (init_weight((num_inputs, num_hiddens)),
                          init_weight((num_hiddens, num_hiddens)),
                          mx.zeros(shape=(num_hiddens,)))
        self.W_xz, self.W_hz, self.b_z = triple()  # Update gate
        self.W_xr, self.W_hr, self.b_r = triple()  # Reset gate
        self.W_xh, self.W_hh, self.b_h = triple()  # Candidate hidden state

    def __call__(self, inputs, H=None):
        """Run the GRU over ``inputs`` one step at a time.

        Args:
            inputs: Iterated along its first axis (presumably shaped
                ``(num_steps, batch_size, num_inputs)``).
            H: Initial hidden state. When ``None``, zeros of shape
                ``(inputs.shape[1], num_hiddens)`` are used.

        Returns:
            ``(outputs, H)``: the list of per-step hidden states and the
            final hidden state.

        Defined in :numref:`sec_gru`"""
        if H is None:
            # Initial state with shape: (batch_size, num_hiddens)
            H = mx.zeros((inputs.shape[1], self.num_hiddens))
        # Dropout is a training-time regularizer. Skip it in eval mode
        # (Trainer calls model.eval() before validation) and at rate 0,
        # where it would only consume random numbers.
        use_dropout = self.training and self.dropout > 0
        outputs = []
        for X in inputs:
            Z = mx.sigmoid(mx.matmul(X, self.W_xz) +
                           mx.matmul(H, self.W_hz) + self.b_z)
            R = mx.sigmoid(mx.matmul(X, self.W_xr) +
                           mx.matmul(H, self.W_hr) + self.b_r)
            H_tilde = mx.tanh(mx.matmul(X, self.W_xh) +
                              mx.matmul(R * H, self.W_hh) + self.b_h)
            H = Z * H + (1 - Z) * H_tilde
            if use_dropout:
                H = d2l.dropout_layer(H, self.dropout)
            outputs.append(H)
        return outputs, H

class StackedGRUScratch(d2l.Module): #@save
    """A stack of ``num_layers`` from-scratch GRU layers.

    Defined in :numref:`sec_deep_rnn`"""
    def __init__(self, num_inputs, num_hiddens, num_layers, sigma=0.01, dropout=0):
        super().__init__()
        self.save_hyperparameters()
        # Pass ``dropout`` on to every layer so the stack actually applies it.
        self.grus = nn.Sequential(*[d2l.GRUScratch(
            num_inputs if i == 0 else num_hiddens, num_hiddens, sigma,
            dropout=dropout)
                                    for i in range(num_layers)])

    def __call__(self, inputs, Hs=None):
        """Run the layers in turn, feeding each one the previous layer's outputs.

        After each layer, its per-step outputs are stacked along axis 0
        before they are passed to the next layer.

        Args:
            inputs: Input sequence for the first layer.
            Hs: Per-layer initial states, or ``None`` for all-zero states.
                Note that ``Hs[i]`` is assigned in place, so a caller-supplied
                container is updated.

        Returns:
            ``(outputs, Hs)``: the last layer's stacked outputs and the
            per-layer final states.
        """
        outputs = inputs
        if Hs is None: Hs = [None] * self.num_layers
        for i in range(self.num_layers):
            outputs, Hs[i] = self.grus.layers[i](outputs, Hs[i])
            outputs = mx.stack(outputs, axis=0)
        return outputs, Hs

class GRU(d2l.RNN): #@save
    """The multilayer GRU model.

    Inherits ``d2l.RNN.__call__`` (which forwards to ``self.rnn``) but
    deliberately skips ``d2l.RNN.__init__``, so no ``RNNScratch`` layer is
    built. The recurrent layer is a ``d2l.StackedGRUScratch`` instead.

    Defined in :numref:`sec_deep_rnn`"""
    def __init__(self, num_inputs, num_hiddens, num_layers, dropout=0):
        # Bypass d2l.RNN.__init__ on purpose (see the class docstring).
        d2l.Module.__init__(self)
        self.save_hyperparameters()
        self.rnn = d2l.StackedGRUScratch(
            num_inputs, num_hiddens, num_layers,
            sigma=0.01, dropout=dropout,
        )

def dropout_layer(X, dropout): #@save
    """Apply inverted dropout to ``X``.

    Each element is zeroed with probability ``dropout``. The survivors are
    scaled by ``1 / (1 - dropout)`` so the expected value is unchanged.

    Args:
        X: Input array.
        dropout: Drop probability in ``[0, 1]``.

    Returns:
        An array shaped like ``X``. A floating-point ``X`` keeps its dtype,
        because the mask is built in ``X.dtype``.

    Defined in :numref:`sec_dropout`"""
    assert 0 <= dropout <= 1
    if dropout == 1: return mx.zeros_like(X)
    # Nothing to drop: skip drawing random numbers entirely.
    if dropout == 0: return X
    # Build the mask in X's dtype so e.g. float16 inputs are not promoted.
    mask = (mx.random.uniform(shape=X.shape) > dropout).astype(X.dtype)
    return mask * X / (1.0 - dropout)

def add_to_class(Class): #@save
    """Register functions as methods in created class.

    Use as ``@add_to_class(SomeClass)``: the decorated function is attached
    to ``SomeClass`` under its own ``__name__``.

    Returns:
        A decorator that returns the function unchanged, so the decorated
        name stays bound in the defining scope instead of becoming ``None``.

    Defined in :numref:`sec_oo-design`"""
    def wrapper(obj):
        setattr(Class, obj.__name__, obj)
        return obj
    return wrapper

def num_gpus(): #@save
    """Report how many GPU devices are available.

    This implementation always reports exactly one GPU. It is not probed
    at runtime; presumably it assumes the single GPU that MLX targets.

    Defined in :numref:`sec_use_gpu`"""
    reported_devices = 1
    return reported_devices

def gpu(i=0): #@save
    """Make the GPU MLX's default device and return that device.

    The index ``i`` is accepted for API compatibility but is not used.

    Defined in :numref:`sec_use_gpu`"""
    mx.set_default_device(mx.gpu)
    current_device = mx.default_device()
    return current_device

def try_gpu(i=0): #@save
    """Return gpu(i) if exists, otherwise return cpu().

    A GPU with index ``i`` is considered to exist when ``num_gpus() > i``.
    Note that ``gpu(i)`` also makes the GPU MLX's default device; the CPU
    fallback leaves the default device untouched.

    Defined in :numref:`sec_use_gpu`"""
    if num_gpus() >= i + 1:
        return gpu(i)
    # No GPU with index i: fall back to the CPU, as documented above.
    return mx.Device(mx.cpu)

class Trainer(d2l.HyperParameters): #@save
    """The base class for training models with data.

    Defined in :numref:`subsec_oo-design-models`"""
    def __init__(self, max_epochs, num_gpus=0, gradient_clip_val=0):
        """Store the training hyperparameters and select GPU devices.

        The requested ``num_gpus`` is capped at ``d2l.num_gpus()``.

        Defined in :numref:`sec_use_gpu`"""
        self.save_hyperparameters()
        self.gpus = [d2l.gpu(i) for i in range(min(num_gpus, d2l.num_gpus()))]

    def prepare_data(self, data):
        """Fetch the data loaders and record how many batches each yields."""
        self.train_dataloader = data.train_dataloader()
        self.val_dataloader = data.val_dataloader()
        self.num_train_batches = len(self.train_dataloader)
        self.num_val_batches = (len(self.val_dataloader)
                                if self.val_dataloader is not None else 0)

    def prepare_model(self, model):
        """Link ``model`` to this trainer and size its progress board.

        Defined in :numref:`sec_use_gpu`"""
        model.trainer = self
        model.board.xlim = [0, self.max_epochs]
        if self.gpus:
            # gpu() makes the GPU MLX's default device.
            gpu()
        self.model = model

    def fit(self, model, data):
        """Train ``model`` on ``data`` for ``max_epochs`` epochs."""
        self.prepare_data(data)
        self.prepare_model(model)
        self.optim = model.configure_optimizers()
        self.epoch = 0
        self.train_batch_idx = 0
        self.val_batch_idx = 0
        for self.epoch in range(self.max_epochs):
            self.fit_epoch()

    def prepare_batch(self, batch):
        """Return the batch, as a list when GPUs are in use.

        Defined in :numref:`sec_use_gpu`"""
        if self.gpus:
            gpu()
            batch = [a for a in batch]
        return batch

    def fit_epoch(self):
        """Run one training pass and, if available, one validation pass.

        Defined in :numref:`sec_linear_scratch`"""
        self.model.train(True)
        for batch in self.train_dataloader:
            loss, grads = self.model.training_step(self.prepare_batch(batch))
            if self.gradient_clip_val > 0:
                grads = self.clip_gradients(self.gradient_clip_val, grads)
            self.optim.update(model=self.model, gradients=grads)
            # Evaluate the optimizer state as well as the parameters so the
            # lazily built computation graph does not keep growing.
            mx.eval(self.model.parameters(), self.optim.state)
            self.train_batch_idx += 1
        if self.val_dataloader is None:
            return
        self.model.eval()
        for batch in self.val_dataloader:
            self.model.validation_step(self.prepare_batch(batch))
            self.val_batch_idx += 1

    def clip_gradients(self, grad_clip_val, grads):
        """Rescale ``grads`` so their global L2 norm is at most ``grad_clip_val``.

        The norm is taken over all gradient leaves together. If it is
        already below the threshold, the gradients are returned unscaled.

        Defined in :numref:`sec_rnn-scratch`"""
        grad_leaves = mlx.utils.tree_flatten(grads)
        norm = mx.sqrt(sum((x[1] ** 2).sum() for x in grad_leaves))
        clip = lambda grad: mx.where(norm < grad_clip_val, grad, grad * (grad_clip_val / norm))
        return mlx.utils.tree_map(clip, grads)


def download_new(url, folder='../data', sha1_hash=None): #@save
    """Download a file to folder and return the local filepath.

    If ``url`` does not start with ``http`` it is treated as a key into
    ``DATA_HUB`` (kept for backward compatibility). When ``sha1_hash`` is
    given and a local file with that SHA-1 already exists, the download is
    skipped.

    Raises:
        requests.HTTPError: if the server answers with an error status.

    Defined in :numref:`sec_utils`"""
    if not url.startswith('http'):
        # For backward compatibility
        url, sha1_hash = DATA_HUB[url]
    os.makedirs(folder, exist_ok=True)
    fname = os.path.join(folder, url.split('/')[-1])
    # Check if hit cache
    if os.path.exists(fname) and sha1_hash:
        sha1 = hashlib.sha1()
        with open(fname, 'rb') as f:
            for chunk in iter(lambda: f.read(1048576), b''):
                sha1.update(chunk)
        if sha1.hexdigest() == sha1_hash:
            return fname
    # Download
    print(f'Downloading {fname} from {url}...')
    # Write to a temporary name first, so an interrupted download never
    # leaves a truncated file that looks like a valid cache entry.
    tmp_fname = fname + '.part'
    with requests.get(url, stream=True, verify=True) as r:
        # Fail loudly instead of saving an HTTP error page as the data file.
        r.raise_for_status()
        with open(tmp_fname, 'wb') as f:
            for chunk in r.iter_content(chunk_size=1048576):
                f.write(chunk)
    os.replace(tmp_fname, fname)
    return fname

def extract(filename, folder=None): #@save
    """Extract a zip/tar file into folder.

    Args:
        filename: Path to a ``.zip``, ``.tar``, ``.gz`` or ``.tgz`` archive.
            Non-zip archives are opened with ``tarfile`` using transparent
            compression detection.
        folder: Destination directory. Defaults to the directory that
            contains ``filename``.

    Defined in :numref:`sec_utils`"""
    base_dir = os.path.dirname(filename)
    _, ext = os.path.splitext(filename)
    assert ext in ('.zip', '.tar', '.gz', '.tgz'), 'Only support zip/tar files.'
    if folder is None:
        folder = base_dir
    # Context managers make sure the archive handle is always closed.
    if ext == '.zip':
        with zipfile.ZipFile(filename, 'r') as fp:
            fp.extractall(folder)
    else:
        with tarfile.open(filename, 'r') as fp:
            # Where supported (Python 3.12+ and security backports), reject
            # members with absolute paths or '..' components.
            if hasattr(tarfile, 'data_filter'):
                fp.extractall(folder, filter='data')
            else:
                fp.extractall(folder)

class Classifier(d2l.Module): #@save
    """The base class of classification models.

    Defined in :numref:`sec_classification`"""
    def validation_step(self, batch):
        inputs, labels = batch[:-1], batch[-1]
        Y_hat = self(*inputs)
        self.plot('loss', self.loss(Y_hat, labels), train=False)
        self.plot('acc', self.accuracy(Y_hat, labels), train=False)

    def accuracy(self, Y_hat, Y, averaged=True):
        """Score predictions against labels.

        Returns the fraction of correct predictions when ``averaged`` is
        true, otherwise a float32 array of per-example 0/1 indicators.

        Defined in :numref:`sec_classification`"""
        num_classes = Y_hat.shape[-1]
        predicted = Y_hat.reshape((-1, num_classes)).argmax(axis=1).astype(Y.dtype)
        correct = (predicted == Y.reshape(-1)).astype(mx.float32)
        if averaged:
            return correct.mean()
        return correct

    def loss(self, Y_hat, Y, averaged=True):
        """Cross-entropy over flattened logits and labels.

        Defined in :numref:`sec_softmax_concise`"""
        logits = Y_hat.reshape((-1, Y_hat.shape[-1]))
        labels = Y.reshape((-1,))
        reduction = 'mean' if averaged else 'none'
        return nn.losses.cross_entropy(logits, labels, reduction=reduction)

    def layer_summary(self, X_shape):
        """Push a random input through ``self.net.layers`` and print each output shape.

        Defined in :numref:`sec_lenet`"""
        X = mx.random.normal(shape=X_shape)
        for layer in self.net.layers:
            X = layer(X)
            print(layer.__class__.__name__, 'output shape:\t', X.shape)


class MTFraEng(d2l.DataModule): #@save
    """The English-French dataset.

    Defined in :numref:`sec_machine_translation`"""
    def _download_new(self):
        """Download and unpack ``fra-eng.zip`` and return ``fra.txt`` as text."""
        d2l.extract(d2l.download_new(
            d2l.DATA_URL+'fra-eng.zip', self.root,
            '94646ad1522d915e7b0f9296181140edcf86a4f5'))
        with open(self.root + '/fra-eng/fra.txt', encoding='utf-8') as f:
            return f.read()

    def _preprocess(self, text):
        """Normalize spaces, lower-case, and separate punctuation from words.

        Defined in :numref:`sec_machine_translation`"""
        # Replace non-breaking space with space
        text = text.replace('\u202f', ' ').replace('\xa0', ' ')
        # Insert space between words and punctuation marks
        no_space = lambda char, prev_char: char in ',.!?' and prev_char != ' '
        out = [' ' + char if i > 0 and no_space(char, text[i - 1]) else char
               for i, char in enumerate(text.lower())]
        return ''.join(out)

    def _tokenize(self, text, max_examples=None):
        """Split preprocessed text into source and target token lists.

        Each line is expected to be ``source<TAB>target``; other lines are
        skipped. ``<eos>`` is appended to both sides.

        Args:
            text: Output of :meth:`_preprocess`.
            max_examples: If truthy, stop after exactly this many pairs.

        Defined in :numref:`sec_machine_translation`"""
        src, tgt = [], []
        for line in text.split('\n'):
            # Count collected pairs rather than raw line indices. An index
            # test like ``i > max_examples`` would keep one pair too many.
            if max_examples and len(src) >= max_examples: break
            parts = line.split('\t')
            if len(parts) == 2:
                # Skip empty tokens
                src.append([t for t in f'{parts[0]} <eos>'.split(' ') if t])
                tgt.append([t for t in f'{parts[1]} <eos>'.split(' ') if t])
        return src, tgt

    def __init__(self, batch_size, num_steps=9, num_train=512, num_val=128):
        """Download the corpus and build the tensors and vocabularies.

        Defined in :numref:`sec_machine_translation`"""
        super(MTFraEng, self).__init__()
        self.save_hyperparameters()
        self.arrays, self.src_vocab, self.tgt_vocab = self._build_arrays(
            self._download_new())

    def _build_arrays(self, raw_text, src_vocab=None, tgt_vocab=None):
        """Turn raw text into padded index arrays.

        Every sentence is padded with ``<pad>`` or trimmed to ``num_steps``
        tokens. Target sentences additionally get a leading ``<bos>``.
        Vocabularies (``min_freq=2``) are built unless they are passed in.

        Returns:
            ``((src, tgt[:, :-1], src_valid_len, tgt[:, 1:]), src_vocab,
            tgt_vocab)``, i.e. the decoder input and the labels are the
            target shifted by one position.

        Defined in :numref:`subsec_mt_data_loading`"""
        def _build_array(sentences, vocab, is_tgt=False):
            pad_or_trim = lambda seq, t: (
                seq[:t] if len(seq) > t else seq + ['<pad>'] * (t - len(seq)))
            sentences = [pad_or_trim(s, self.num_steps) for s in sentences]
            if is_tgt:
                sentences = [['<bos>'] + s for s in sentences]
            if vocab is None:
                vocab = d2l.Vocab(sentences, min_freq=2)
            array = mx.array([vocab[s] for s in sentences])
            # Number of non-padding tokens per sentence.
            valid_len = (array != vocab['<pad>']).astype(mx.int32).sum(1)
            return array, vocab, valid_len
        src, tgt = self._tokenize(self._preprocess(raw_text),
                                  self.num_train + self.num_val)
        src_array, src_vocab, src_valid_len = _build_array(src, src_vocab)
        tgt_array, tgt_vocab, _ = _build_array(tgt, tgt_vocab, True)
        return ((src_array, tgt_array[:,:-1], src_valid_len, tgt_array[:,1:]),
                src_vocab, tgt_vocab)

    def get_dataloader(self, train):
        """The first ``num_train`` pairs train; the remaining pairs validate.

        Defined in :numref:`subsec_mt_data_loading`"""
        idx = slice(0, self.num_train) if train else slice(self.num_train, None)
        return self.get_tensorloader(self.arrays, train, idx)

    def build(self, src_sentences, tgt_sentences):
        """Encode new sentence pairs with the existing vocabularies.

        Defined in :numref:`subsec_mt_data_loading`"""
        raw_text = '\n'.join([src + '\t' + tgt for src, tgt in zip(
            src_sentences, tgt_sentences)])
        arrays, _, _ = self._build_arrays(
            raw_text, self.src_vocab, self.tgt_vocab)
        return arrays


class Encoder(d2l.Module):  #@save
    """Base encoder interface for the encoder-decoder architecture.

    Concrete encoders override ``__call__``."""
    def __init__(self):
        super().__init__()

    def __call__(self, X, *args):
        raise NotImplementedError

class Decoder(d2l.Module):  #@save
    """Base decoder interface for the encoder-decoder architecture.

    Concrete decoders implement ``init_state`` (turn the encoder outputs
    into the initial decoder state) and ``__call__``.

    Defined in :numref:`sec_encoder-decoder`"""
    def __init__(self):
        super().__init__()

    def init_state(self, enc_all_outputs, *args):
        raise NotImplementedError

    def __call__(self, X, state):
        raise NotImplementedError

class EncoderDecoder(d2l.Classifier):   #@save
    """Base class for the encoder-decoder architecture.

    Defined in :numref:`sec_encoder-decoder`"""
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def __call__(self, enc_X, dec_X, *args):
        enc_all_outputs = self.encoder(enc_X, *args)
        dec_state = self.decoder.init_state(enc_all_outputs, *args)
        # The decoder returns (outputs, state); only the outputs are needed.
        decoded = self.decoder(dec_X, dec_state)
        return decoded[0]

    def predict_step(self, batch, device, num_steps,
                     save_attention_weights=False):
        """Greedily decode ``num_steps`` tokens for a batch.

        Args:
            batch: ``(src, tgt, src_valid_len, _)``. Only the first column
                of ``tgt`` is used, as the first decoder input.
            device: Not used by this implementation.
            num_steps: Number of tokens to generate.
            save_attention_weights: If true, collect
                ``self.decoder.attention_weights`` after every step.

        Returns:
            ``(predictions, attention_weights)``. ``predictions`` joins the
            per-step argmax tokens along axis 1.

        Defined in :numref:`sec_seq2seq_decoder`"""
        src, tgt, src_valid_len, _ = batch
        enc_all_outputs = self.encoder(src, src_valid_len)
        dec_state = self.decoder.init_state(enc_all_outputs, src_valid_len)
        step_input = tgt[:, 0].reshape(-1, 1)
        predictions, attention_weights = [], []
        for _ in range(num_steps):
            Y, dec_state = self.decoder(step_input, dec_state)
            # The greedy choice becomes the next step's input.
            step_input = Y.argmax(2)
            predictions.append(step_input)
            # Save attention weights (to be covered later)
            if save_attention_weights:
                attention_weights.append(self.decoder.attention_weights)
        return mx.concatenate(predictions, 1), attention_weights

def init_seq2seq(array):    #@save
    """Glorot-uniform initialize an array with two or more dimensions.

    Arrays with ``ndim <= 1`` (e.g. biases) are returned unchanged.

    Defined in :numref:`sec_seq2seq`"""
    if array.ndim <= 1:
        return array
    glorot = nn.init.glorot_uniform()
    return glorot(array)

class Seq2SeqEncoder(d2l.Encoder):  #@save
    """The RNN encoder for sequence-to-sequence learning.

    Defined in :numref:`sec_seq2seq`"""
    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = d2l.GRU(embed_size, num_hiddens, num_layers, dropout)
        # Re-initialize the parameters of every Linear / GRU submodule.
        for module in self.modules():
            if isinstance(module, (nn.Linear, d2l.GRU)):
                module.update(mlx.utils.tree_map(init_seq2seq, module.parameters()))

    def __call__(self, X, *args):
        # X: (batch_size, num_steps) token ids; transposing gives time-major
        # embeddings of shape (num_steps, batch_size, embed_size).
        embs = self.embedding(X.T.astype(mx.int64))
        outputs, state = self.rnn(embs)
        # outputs: (num_steps, batch_size, num_hiddens)
        # state (per-layer final states packed into one array):
        # (num_layers, batch_size, num_hiddens)
        return outputs, mx.array(state)

class Seq2SeqDecoder(d2l.Decoder):  #@save
    """The RNN decoder for sequence to sequence learning.

    Defined in :numref:`sec_seq2seq`"""
    def __init__(self, num_inputs, vocab_size, embed_size, num_hiddens, num_layers,
                 dropout=0):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        # Each step sees its token embedding concatenated with the context.
        self.rnn = d2l.GRU(embed_size + num_hiddens, num_hiddens,
                           num_layers, dropout)
        # presumably num_inputs == num_hiddens (the GRU output width) — confirm
        self.dense = nn.Linear(num_inputs, vocab_size)
        # Re-initialize the parameters of every Linear / GRU submodule.
        for module in self.modules():
            if isinstance(module, (nn.Linear, d2l.GRU)):
                module.update(mlx.utils.tree_map(init_seq2seq, module.parameters()))

    def init_state(self, enc_all_outputs, *args):
        # The decoder state is the encoder's (outputs, final states) pair.
        return enc_all_outputs

    def __call__(self, X, state):
        enc_output, hidden_state = state
        # X: (batch_size, num_steps) -> embs: (num_steps, batch_size, embed_size)
        embs = self.embedding(X.T.astype(mx.int32))
        # The encoder output at its last time step is the context; repeat
        # it for every decoding step: (num_steps, batch_size, num_hiddens).
        context = mx.tile(enc_output[-1], (embs.shape[0], 1, 1))
        rnn_input = mx.concatenate((embs, context), -1)
        outputs, hidden_state = self.rnn(rnn_input, hidden_state)
        # (num_steps, batch_size, vocab_size) -> (batch_size, num_steps, vocab_size)
        logits = self.dense(outputs).swapaxes(0, 1)
        return logits, [enc_output, hidden_state]

class Seq2Seq(d2l.EncoderDecoder):  #@save
    """The RNN encoder--decoder for sequence to sequence learning.

    Defined in :numref:`sec_seq2seq_decoder`"""
    def __init__(self, encoder, decoder, tgt_pad, lr):
        super().__init__(encoder, decoder)
        self.save_hyperparameters()

    def validation_step(self, batch):
        inputs, labels = batch[:-1], batch[-1]
        self.plot('loss', self.loss(self(*inputs), labels), train=False)

    def configure_optimizers(self):
        # Adam with learning rate ``lr``.
        return optim.Adam(learning_rate=self.lr)

    def loss(self, Y_hat, Y):
        """Cross-entropy averaged over non-padding target tokens only.

        Defined in :numref:`sec_seq2seq_decoder`"""
        per_token = super(Seq2Seq, self).loss(Y_hat, Y, averaged=False)
        keep = (Y.reshape(-1) != self.tgt_pad).astype(mx.float32)
        return (per_token * keep).sum() / keep.sum()

def bleu(pred_seq, label_seq, k):   #@save
    """Compute the BLEU score of ``pred_seq`` against ``label_seq``.

    Both arguments are space-separated token strings. The score is the
    brevity penalty ``exp(min(0, 1 - len_label / len_pred))`` times
    ``p_n ** (0.5 ** n)`` for ``n = 1 .. min(k, len_pred)``. Here ``p_n`` is
    the clipped n-gram precision.

    Defined in :numref:`sec_seq2seq_decoder`"""
    pred_tokens = pred_seq.split(' ')
    label_tokens = label_seq.split(' ')
    len_pred, len_label = len(pred_tokens), len(label_tokens)
    score = math.exp(min(0, 1 - len_label / len_pred))
    for n in range(1, min(k, len_pred) + 1):
        # Remaining budget of each label n-gram (clipping).
        remaining = collections.Counter(
            ' '.join(label_tokens[i:i + n]) for i in range(len_label - n + 1))
        num_matches = 0
        for i in range(len_pred - n + 1):
            ngram = ' '.join(pred_tokens[i:i + n])
            if remaining[ngram] > 0:
                num_matches += 1
                remaining[ngram] -= 1
        score *= math.pow(num_matches / (len_pred - n + 1), math.pow(0.5, n))
    return score


class Dataset:  #@save
    def __init__(self, *tensors):
        """Wrap equally long arrays; item ``i`` is the tuple of their ``i``-th rows.

        Defined in :numref:`sec_fashion_mnist`"""
        first_axis_sizes = {t.shape[0] for t in tensors}
        assert len(first_axis_sizes) <= 1, "Size mismatch between tensors"
        self.tensors = tensors

    def __getitem__(self, index):
        # Each selected row is converted to an mx.array.
        return tuple(map(lambda t: mx.array(t[index]), self.tensors))

    def __len__(self):
        # Number of samples = length of the first array's leading axis.
        return self.tensors[0].shape[0]

class DataLoader:   #@save
    """Minimal mini-batch iterator over an indexable dataset.

    Batches are assembled by :meth:`collate_fn`.
    """
    def __init__(self, dataset, batch_size=1, shuffle=False, drop_last=False):
        """Set up batching over ``dataset``.

        Args:
            dataset: Object supporting ``len()`` and integer indexing.
            batch_size: Positive number of items per batch.
            shuffle: Shuffle the sample order with the stdlib ``random``
                module (here and at the start of every iteration).
            drop_last: Drop a final batch smaller than ``batch_size``.

        Raises:
            ValueError: if ``batch_size`` is not positive. A batch size of 0
                would otherwise make iteration loop forever.

        Defined in :numref:`sec_fashion_mnist`"""
        if batch_size <= 0:
            raise ValueError(f'batch_size must be a positive integer, got {batch_size}')
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.indices = list(range(len(dataset)))
        self.current_index = 0
        if self.shuffle:
            random.shuffle(self.indices)

    def __iter__(self):
        # The loader is its own iterator: restarting resets the position
        # shared by all loops over this instance.
        self.current_index = 0
        if self.shuffle:
            random.shuffle(self.indices)
        return self

    def __next__(self):
        if self.current_index >= len(self.indices):
            raise StopIteration

        end_index = self.current_index + self.batch_size
        if end_index > len(self.indices):
            if self.drop_last:
                raise StopIteration
            else:
                end_index = len(self.indices)

        batch_indices = self.indices[self.current_index:end_index]
        batch = [self.dataset[i] for i in batch_indices]
        self.current_index = end_index

        return self.collate_fn(batch)

    def __len__(self):
        if self.drop_last:
            return len(self.dataset) // self.batch_size
        else:
            return math.ceil(len(self.dataset) / self.batch_size)

    def collate_fn(self, batch):
        """Combine a list of samples into one batch.

        Samples that are 2- or 4-tuples are split field-wise into separate
        ``mx.array`` objects. Anything else is converted as a whole with
        ``mx.array``.
        """
        if isinstance(batch[0], tuple):
            if (len(batch[0])) == 2:
                data, targets = zip(*batch)
                data = mx.array(data)
                targets = mx.array(targets)
                return data, targets
            if (len(batch[0])) == 4:
                data, decoder_input, src_valid_len, targets = zip(*batch)
                data = mx.array(data)
                decoder_input = mx.array(decoder_input)
                src_valid_len = mx.array(src_valid_len)
                targets = mx.array(targets)
                return data, decoder_input, src_valid_len, targets
        return mx.array(batch)



class AdaptiveAvgPool2d(nn.Module): #@save
    """Average-pool the spatial axes of an input down to a fixed size.

    The target size is ``output_size``. The pooling is done by an
    ``mlx.nn.AvgPool2d`` whose kernel size and stride are derived from the
    input size on every call. Only downsampling (input size >= output size)
    is supported.

    Args:
        output_size (int or tuple): Target ``H`` (giving ``H x H``) or a
            tuple of two ints ``(H, W)``.
    """
    def __init__(self, output_size):
        super().__init__()
        if isinstance(output_size, int):
            output_size = (output_size, output_size)
        elif not (isinstance(output_size, tuple) and len(output_size) == 2
                  and all(isinstance(side, int) for side in output_size)):
            raise ValueError(
                "output_size must be an int or a tuple of two positive ints"
            )
        self.output_h, self.output_w = output_size

        if min(self.output_h, self.output_w) <= 0:
            raise ValueError("output_size dimensions must be positive")

    def __call__(self, x: mx.array) -> mx.array:
        """Pool ``x`` to ``(output_h, output_w)`` spatially.

        Args:
            x (mx.array): Input tensor of shape (N, H_in, W_in, C); axes 1
                and 2 are treated as height and width (channels-last).

        Returns:
            mx.array: Output tensor of shape (N, H_out, W_out, C).

        Raises:
            ValueError: if the input is spatially smaller than the target.
        """
        in_h, in_w = x.shape[1], x.shape[2]
        out_h, out_w = self.output_h, self.output_w

        # Sizes already match: nothing to do.
        if out_h == in_h and out_w == in_w:
            return x

        # Global average pooling is just a mean over both spatial axes.
        if out_h == 1 and out_w == 1:
            return mx.mean(x, axis=(1, 2), keepdims=True)

        if in_h < out_h or in_w < out_w:
            raise ValueError(
                f"Input spatial size ({in_h}x{in_w}) is smaller than target output size "
                f"({out_h}x{out_w}) in at least one dimension. "
                "This AdaptiveAvgPool2d implementation using nn.AvgPool2d "
                "requires input dimensions to be greater than or equal to output dimensions."
            )

        # Floor-divided stride, then the kernel K = I - (O - 1) * S that
        # makes a padding-free pool produce exactly O outputs per axis.
        stride_h, stride_w = in_h // out_h, in_w // out_w
        kernel_h = in_h - (out_h - 1) * stride_h
        kernel_w = in_w - (out_w - 1) * stride_w

        # Defensive: unreachable given the validation above and in __init__.
        if min(kernel_h, kernel_w, stride_h, stride_w) <= 0:
            raise RuntimeError(
                f"Calculated invalid pooling parameters: "
                f"kernel=({kernel_h},{kernel_w}), stride=({stride_h},{stride_w}). "
                f"Input: ({in_h},{in_w}), Output: ({out_h},{out_w})"
            )

        pool = nn.AvgPool2d(
            kernel_size=(kernel_h, kernel_w),
            stride=(stride_h, stride_w),
            padding=0,
        )
        return pool(x)

class resnet18(nn.Module):  #@save
    """A ResNet-18 style classifier.

    Layout:
        * Stem: a 7x7 stride-2 convolution to 64 channels, batch norm,
          ReLU and a 3x3 stride-2 max pool.
        * Four stages of two ``d2l.Residual`` blocks, with 64, 128, 256
          and 512 channels. Every stage after the first opens with a
          ``use_1x1conv=True, strides=2`` residual.
        * Adaptive average pooling to 1x1, flatten, and a linear layer.

    Defined in :numref:`sec_multi_gpu_concise`"""
    def __init__(self, num_classes, in_channels=1):
        super().__init__()

        def make_stage(c_in, c_out, num_residuals, first_block=False):
            return nn.Sequential(*[
                d2l.Residual(c_in, c_out, use_1x1conv=True, strides=2)
                if idx == 0 and not first_block
                else d2l.Residual(c_out, c_out)
                for idx in range(num_residuals)
            ])

        stem = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm(64), nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
        stages = [make_stage(64, 64, 2, first_block=True),
                  make_stage(64, 128, 2),
                  make_stage(128, 256, 2),
                  make_stage(256, 512, 2)]

        self.layers = nn.Sequential(
            stem, *stages,
            d2l.AdaptiveAvgPool2d((1, 1)),
            d2l.Flatten(),
            nn.Linear(512, num_classes))

    def __call__(self, x):
        # The whole network is the single sequential stack built above.
        return self.layers(x)