
Building Custom Embedders: A Deep Dive

· 6 min read
Ali Kahwaji
Creator & Maintainer of RAG Pipeline Utils

Custom embedders allow you to tailor vector representations to your specific domain, improving retrieval accuracy and reducing dependency on external APIs. This deep dive explores the architecture, implementation patterns, and optimization techniques for building production-ready custom embedders.

Why Custom Embedders?

While pre-trained embedders like OpenAI's text-embedding-ada-002 provide excellent general-purpose embeddings, custom embedders offer several advantages:

  • Domain Specialization: Train on domain-specific vocabulary and concepts
  • Cost Control: Eliminate per-request API costs for high-volume applications
  • Latency Optimization: Reduce network overhead with local inference
  • Privacy: Process sensitive data without external API calls
  • Customization: Fine-tune for specific similarity metrics or tasks

Embedder Contract

RAG Pipeline Utils defines a simple but powerful embedder contract:

interface Embedder {
  embed(text: string): Promise<number[]>;
  embedBatch?(texts: string[]): Promise<number[][]>;
}

The contract defines two methods:

  • embed() - Required; transforms a single text into a vector
  • embedBatch() - Optional; processes multiple texts in one call for efficiency
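
Any object that satisfies this contract can be plugged into the pipeline. As a minimal sketch (the hashing scheme and default dimensionality are purely illustrative, useful for tests rather than real retrieval):

const crypto = require("crypto");

class HashEmbedder {
  constructor(dimensions = 8) {
    this.dimensions = dimensions; // must be <= 32, the length of a SHA-256 digest
  }

  // Deterministic, dependency-free vectors derived from a hash of the text
  async embed(text) {
    const digest = crypto.createHash("sha256").update(text).digest();
    const vector = [];
    for (let i = 0; i < this.dimensions; i++) {
      vector.push(digest[i] / 255); // scale each byte into [0, 1]
    }
    return vector;
  }

  // The optional batch method simply maps over embed()
  async embedBatch(texts) {
    return Promise.all(texts.map((text) => this.embed(text)));
  }
}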

Implementation Patterns

Pattern 1: Wrapping Pre-trained Models

The most common pattern wraps pre-trained models from HuggingFace, Sentence Transformers, or custom checkpoints:

const { AutoTokenizer, AutoModel } = require("@xenova/transformers");

class SentenceTransformerEmbedder {
  constructor(options = {}) {
    this.modelName =
      options.modelName || "sentence-transformers/all-MiniLM-L6-v2";
    this.maxLength = options.maxLength || 512;
    this.tokenizer = null;
    this.model = null;
  }

  async initialize() {
    if (!this.tokenizer) {
      this.tokenizer = await AutoTokenizer.from_pretrained(this.modelName);
      this.model = await AutoModel.from_pretrained(this.modelName);
    }
  }

  async embed(text) {
    await this.initialize();

    // Tokenize
    const inputs = await this.tokenizer(text, {
      padding: true,
      truncation: true,
      max_length: this.maxLength,
    });

    // Generate token-level embeddings
    const outputs = await this.model(inputs);

    // Mean pooling over the token dimension
    const embedding = this.meanPooling(
      outputs.last_hidden_state,
      inputs.attention_mask,
    );

    // L2-normalize so cosine similarity reduces to a dot product
    return this.normalize(embedding);
  }

  async embedBatch(texts) {
    await this.initialize();

    const inputs = await this.tokenizer(texts, {
      padding: true,
      truncation: true,
      max_length: this.maxLength,
    });

    const outputs = await this.model(inputs);
    const embeddings = this.meanPooling(
      outputs.last_hidden_state,
      inputs.attention_mask,
    );

    return embeddings.map((emb) => this.normalize(emb));
  }

  meanPooling(lastHiddenState, attentionMask) {
    // Mean pooling: average token embeddings, ignoring padding tokens
    const masked = lastHiddenState.mul(attentionMask.unsqueeze(-1));
    const summed = masked.sum(1);
    const counts = attentionMask
      .sum(1)
      .unsqueeze(-1)
      .clamp(1); // clamp to at least 1 to avoid division by zero
    return summed.div(counts);
  }

  normalize(embedding) {
    const norm = Math.sqrt(embedding.reduce((sum, val) => sum + val * val, 0));
    return embedding.map((val) => val / norm);
  }
}

module.exports = SentenceTransformerEmbedder;
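
With the class above saved as a local module (the path below is just an example), a quick smoke test looks like this. The model weights download on first use, so the first call is slow:

const SentenceTransformerEmbedder = require("./embedders/sentence-transformer-embedder");

const embedder = new SentenceTransformerEmbedder();

async function smokeTest() {
  // Single text
  const vector = await embedder.embed("How does authentication work?");
  console.log(vector.length); // embedding dimensionality (384 for all-MiniLM-L6-v2)

  // Batch of texts
  const vectors = await embedder.embedBatch([
    "reset my password",
    "configure OAuth scopes",
  ]);
  console.log(vectors.length); // 2
}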

Pattern 2: Fine-Tuned Domain Models

For specialized domains, fine-tune base models on your corpus:

# Training script (Python)
from sentence_transformers import SentenceTransformer, InputExample, losses
from torch.utils.data import DataLoader

# Load base model
model = SentenceTransformer('all-MiniLM-L6-v2')

# Prepare training data
train_examples = [
    InputExample(texts=['medical diagnosis', 'patient symptoms'], label=0.9),
    InputExample(texts=['billing code', 'insurance claim'], label=0.8),
    # Add domain-specific pairs
]

train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=16)

# Fine-tune with cosine-similarity loss
train_loss = losses.CosineSimilarityLoss(model)

model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    epochs=3,
    warmup_steps=100
)

# Save fine-tuned model
model.save('models/medical-embedder-v1')

Use in Node.js:

const { AutoModel } = require("@xenova/transformers");

class FineTunedMedicalEmbedder {
  constructor() {
    this.modelPath = "./models/medical-embedder-v1";
  }

  async embed(text) {
    // Load fine-tuned checkpoint
    const model = await AutoModel.from_pretrained(this.modelPath);
    // ... embedding logic
  }
}
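
Because the SentenceTransformerEmbedder from Pattern 1 already accepts a model name or path, another option is to reuse it directly rather than writing a new class. A sketch, assuming the fine-tuned checkpoint has been converted to a format the JS runtime can load (e.g. ONNX for @xenova/transformers):

const embedder = new SentenceTransformerEmbedder({
  modelName: "./models/medical-embedder-v1", // path to the converted checkpoint
});

// embedder.embed() and embedder.embedBatch() now produce domain-tuned vectors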

Pattern 3: Hybrid Embedders

Combine multiple embedding sources for enhanced retrieval:

class HybridEmbedder {
  constructor(options = {}) {
    this.semanticEmbedder = new SentenceTransformerEmbedder();
    this.lexicalEmbedder = new TFIDFEmbedder();
    this.semanticWeight = options.semanticWeight || 0.7;
    this.lexicalWeight = options.lexicalWeight || 0.3;
  }

  async embed(text) {
    const [semantic, lexical] = await Promise.all([
      this.semanticEmbedder.embed(text),
      this.lexicalEmbedder.embed(text),
    ]);

    // Weighted combination
    return this.combine(semantic, lexical);
  }

  combine(semantic, lexical) {
    // Pad to the larger dimensionality, then add the weighted components
    const maxDim = Math.max(semantic.length, lexical.length);
    const combined = new Array(maxDim).fill(0);

    for (let i = 0; i < semantic.length; i++) {
      combined[i] += semantic[i] * this.semanticWeight;
    }

    for (let i = 0; i < lexical.length; i++) {
      combined[i] += lexical[i] * this.lexicalWeight;
    }

    return combined;
  }
}
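
The TFIDFEmbedder referenced above is not defined in this post; a minimal sketch of what such a lexical embedder might look like (vocabulary fitting and scoring are deliberately simplified) is:

class TFIDFEmbedder {
  // In practice, fit the vocabulary and IDF weights on your corpus and pass them in
  constructor(vocabulary = [], idf = []) {
    this.vocabulary = vocabulary;
    this.idf = idf;
  }

  async embed(text) {
    const tokens = text.toLowerCase().split(/\W+/).filter(Boolean);

    // Term frequencies for this text
    const tf = new Map();
    for (const token of tokens) {
      tf.set(token, (tf.get(token) || 0) + 1);
    }

    // One dimension per vocabulary term: tf * idf
    return this.vocabulary.map((term, i) => (tf.get(term) || 0) * this.idf[i]);
  }
}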

Optimization Techniques

1. Caching

Cache embeddings for frequently accessed texts:

class CachedEmbedder {
  constructor(baseEmbedder, options = {}) {
    this.baseEmbedder = baseEmbedder;
    this.cache = new Map();
    this.maxCacheSize = options.maxCacheSize || 10000;
  }

  async embed(text) {
    const hash = this.hash(text);

    if (this.cache.has(hash)) {
      // Re-insert on hit so Map insertion order tracks recency (true LRU)
      const cached = this.cache.get(hash);
      this.cache.delete(hash);
      this.cache.set(hash, cached);
      return cached;
    }

    const embedding = await this.baseEmbedder.embed(text);

    // Evict the least recently used entry (oldest insertion) when full
    if (this.cache.size >= this.maxCacheSize) {
      const oldestKey = this.cache.keys().next().value;
      this.cache.delete(oldestKey);
    }

    this.cache.set(hash, embedding);
    return embedding;
  }

  hash(text) {
    // Hash the text so long documents produce compact, fixed-size cache keys
    return require("crypto").createHash("sha256").update(text).digest("hex");
  }
}

2. Batch Processing

Maximize GPU utilization with batching:

class BatchOptimizedEmbedder {
  constructor(baseEmbedder, options = {}) {
    this.baseEmbedder = baseEmbedder;
    this.batchSize = options.batchSize || 32;
    this.queue = [];
    this.processing = false;
  }

  async embed(text) {
    return new Promise((resolve, reject) => {
      this.queue.push({ text, resolve, reject });

      if (!this.processing) {
        this.processBatch();
      }
    });
  }

  async processBatch() {
    this.processing = true;

    while (this.queue.length > 0) {
      const batch = this.queue.splice(0, this.batchSize);
      const texts = batch.map((item) => item.text);

      try {
        const embeddings = await this.baseEmbedder.embedBatch(texts);

        batch.forEach((item, index) => {
          item.resolve(embeddings[index]);
        });
      } catch (error) {
        batch.forEach((item) => item.reject(error));
      }
    }

    this.processing = false;
  }
}
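
Callers still invoke embed() one text at a time; concurrent calls are transparently coalesced into batches of up to batchSize. For example:

const embedder = new BatchOptimizedEmbedder(new SentenceTransformerEmbedder(), {
  batchSize: 32,
});

async function embedChunks(chunks) {
  // If chunks has 500 entries, the queue is drained 32 at a time, so the model
  // runs ~16 batched forward passes instead of 500 single-text calls
  return Promise.all(chunks.map((chunk) => embedder.embed(chunk)));
}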

3. Quantization

Reduce memory footprint with int8 quantization:

class QuantizedEmbedder {
  constructor(baseEmbedder) {
    this.baseEmbedder = baseEmbedder;
  }

  async embed(text) {
    const embedding = await this.baseEmbedder.embed(text);

    // Quantize to int8 (-128 to 127); store min and max alongside the vector
    // (e.g. as metadata) so it can be dequantized later
    const min = Math.min(...embedding);
    const max = Math.max(...embedding);
    const scale = 255 / (max - min || 1); // guard against constant vectors

    return embedding.map((val) => {
      const quantized = Math.round((val - min) * scale) - 128;
      return Math.max(-128, Math.min(127, quantized));
    });
  }

  dequantize(quantized, min, max) {
    const scale = (max - min) / 255;
    return quantized.map((val) => (val + 128) * scale + min);
  }
}
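
For scale, a 384-dimensional float32 vector (the output size of all-MiniLM-L6-v2) occupies 1,536 bytes, while its int8 counterpart needs 384 bytes plus two floats of range metadata, roughly a 4x reduction, or about 1.1 GB saved per million stored vectors.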

Integration with RAG Pipeline Utils

const { createRagPipeline } = require("@devilsdev/rag-pipeline-utils");
// SentenceTransformerEmbedder, BatchOptimizedEmbedder, and CachedEmbedder are
// the classes defined above, required from your own local modules

// Create embedder with optimizations
const embedder = new CachedEmbedder(
  new BatchOptimizedEmbedder(
    new SentenceTransformerEmbedder({
      modelName: "sentence-transformers/all-mpnet-base-v2",
    }),
  ),
);

// Use in pipeline
const pipeline = createRagPipeline({
  embedder,
  retriever: myRetriever,
  llm: myLLM,
});

// Embed documents
await pipeline.ingest("./documents");

// Query uses same embedder
const result = await pipeline.query("How does authentication work?");

Performance Benchmarks

Comparative performance on 10,000 documents:

Embedder                    | Throughput     | Latency (p95) | Memory
OpenAI API                  | 50 texts/sec   | 200ms         | Minimal
Sentence Transformers (CPU) | 100 texts/sec  | 100ms         | 500MB
Sentence Transformers (GPU) | 500 texts/sec  | 20ms          | 2GB
Cached + Batched            | 1000 texts/sec | 10ms          | 1GB

Best Practices

  1. Choose the Right Base Model: Start with proven models like all-mpnet-base-v2 or all-MiniLM-L6-v2
  2. Fine-tune When Necessary: Only fine-tune when general models underperform
  3. Implement Caching: Cache embeddings for static documents
  4. Use Batch Processing: Always implement batch methods for large-scale operations
  5. Monitor Quality: Track retrieval metrics to ensure embedding quality
  6. Version Models: Use model versioning to enable A/B testing (see the sketch after this list)
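
One way to make model versioning concrete is a thin wrapper that records which model produced each vector, so indexes built with different models are never mixed and a candidate model can be A/B tested against the current one. The class and names below are illustrative, not part of the library:

class VersionedEmbedder {
  constructor(baseEmbedder, version) {
    this.baseEmbedder = baseEmbedder;
    // Store the version string with every indexed vector so query-time and
    // index-time models can be matched
    this.version = version;
  }

  async embed(text) {
    return this.baseEmbedder.embed(text);
  }

  async embedBatch(texts) {
    return this.baseEmbedder.embedBatch
      ? this.baseEmbedder.embedBatch(texts)
      : Promise.all(texts.map((t) => this.embed(t)));
  }
}

// Keep one index per version and compare retrieval metrics between them
const control = new VersionedEmbedder(
  new SentenceTransformerEmbedder(),
  "all-MiniLM-L6-v2",
);
const candidate = new VersionedEmbedder(
  new SentenceTransformerEmbedder({ modelName: "./models/medical-embedder-v1" }),
  "medical-embedder-v1",
);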

Conclusion

Custom embedders provide flexibility and control for production RAG systems. By understanding the embedder contract, implementation patterns, and optimization techniques, you can build embedders tailored to your domain while maintaining high performance and cost efficiency.

Further Reading