Building Custom Embedders: A Deep Dive
Custom embedders allow you to tailor vector representations to your specific domain, improving retrieval accuracy and reducing dependency on external APIs. This deep dive explores the architecture, implementation patterns, and optimization techniques for building production-ready custom embedders.
Why Custom Embedders?
While pre-trained embedders like OpenAI's text-embedding-ada-002 provide excellent general-purpose embeddings, custom embedders offer several advantages:
- Domain Specialization: Train on domain-specific vocabulary and concepts
- Cost Control: Eliminate per-request API costs for high-volume applications
- Latency Optimization: Reduce network overhead with local inference
- Privacy: Process sensitive data without external API calls
- Customization: Fine-tune for specific similarity metrics or tasks
Embedder Contract
RAG Pipeline Utils defines a simple but powerful embedder contract:
```typescript
interface Embedder {
  embed(text: string): Promise<number[]>;
  embedBatch?(texts: string[]): Promise<number[][]>;
}
```
The contract defines two methods, only the first of which is required:
- `embed()`: transforms a single text into a vector
- `embedBatch()`: optional batch processing for efficiency
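Any object with these methods satisfies the contract. As a minimal sketch (random vectors stand in for real inference), an implementation without a native batch path can simply fall back to per-text calls:

```javascript
// Minimal contract-conforming embedder (illustrative only)
class DummyEmbedder {
  async embed(text) {
    // A real embedder would run model inference here
    return Array.from({ length: 384 }, () => Math.random());
  }

  async embedBatch(texts) {
    // Naive fallback; real implementations share one forward pass
    return Promise.all(texts.map((t) => this.embed(t)));
  }
}
```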
Implementation Patterns
Pattern 1: Wrapping Pre-trained Models
The most common pattern wraps pre-trained models from HuggingFace, Sentence Transformers, or custom checkpoints:
```javascript
const { AutoTokenizer, AutoModel } = require("@xenova/transformers");

class SentenceTransformerEmbedder {
  constructor(options = {}) {
    this.modelName =
      options.modelName || "sentence-transformers/all-MiniLM-L6-v2";
    this.maxLength = options.maxLength || 512;
    this.tokenizer = null;
    this.model = null;
  }

  async initialize() {
    // Lazy-load so construction stays cheap
    if (!this.tokenizer) {
      this.tokenizer = await AutoTokenizer.from_pretrained(this.modelName);
      this.model = await AutoModel.from_pretrained(this.modelName);
    }
  }

  async embed(text) {
    // Single-text embedding is just a batch of one
    const [embedding] = await this.embedBatch([text]);
    return embedding;
  }

  async embedBatch(texts) {
    await this.initialize();

    // Tokenize: pad to the longest sequence, truncate to maxLength
    const inputs = await this.tokenizer(texts, {
      padding: true,
      truncation: true,
      max_length: this.maxLength,
    });

    // Generate token-level embeddings
    const outputs = await this.model(inputs);

    // Mean pooling, then L2 normalization
    const embeddings = this.meanPooling(
      outputs.last_hidden_state,
      inputs.attention_mask,
    );
    return embeddings.map((emb) => this.normalize(emb));
  }

  meanPooling(lastHiddenState, attentionMask) {
    // Average token vectors per text, skipping padding positions
    const [batch, seqLen, hidden] = lastHiddenState.dims;
    const states = lastHiddenState.data;
    const mask = attentionMask.data;
    const pooled = [];
    for (let b = 0; b < batch; b++) {
      const sum = new Float32Array(hidden);
      let count = 0;
      for (let t = 0; t < seqLen; t++) {
        if (Number(mask[b * seqLen + t]) === 0) continue;
        count += 1;
        const offset = (b * seqLen + t) * hidden;
        for (let h = 0; h < hidden; h++) sum[h] += states[offset + h];
      }
      // Guard against an all-padding row
      pooled.push(Array.from(sum, (v) => v / Math.max(count, 1)));
    }
    return pooled;
  }

  normalize(embedding) {
    const norm =
      Math.sqrt(embedding.reduce((sum, val) => sum + val * val, 0)) || 1;
    return embedding.map((val) => val / norm);
  }
}

module.exports = SentenceTransformerEmbedder;
```
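Usage then looks like the following sketch (run inside an async context; the model weights are downloaded and cached on first use):

```javascript
const embedder = new SentenceTransformerEmbedder();

const vector = await embedder.embed("How does authentication work?");
console.log(vector.length); // 384 dimensions for all-MiniLM-L6-v2

const vectors = await embedder.embedBatch(["login flow", "logout flow"]);
```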
Pattern 2: Fine-Tuned Domain Models
For specialized domains, fine-tune base models on your corpus:
```python
# Training script (Python)
from sentence_transformers import SentenceTransformer, InputExample, losses
from torch.utils.data import DataLoader

# Load base model
model = SentenceTransformer('all-MiniLM-L6-v2')

# Prepare training data: text pairs with similarity labels in [0, 1]
train_examples = [
    InputExample(texts=['medical diagnosis', 'patient symptoms'], label=0.9),
    InputExample(texts=['billing code', 'insurance claim'], label=0.8),
    # Add domain-specific pairs
]
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=16)

# Fine-tune with cosine similarity loss
train_loss = losses.CosineSimilarityLoss(model)
model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    epochs=3,
    warmup_steps=100
)

# Save fine-tuned model
model.save('models/medical-embedder-v1')
```
Use in Node.js:
```javascript
const { AutoModel } = require("@xenova/transformers");

class FineTunedMedicalEmbedder {
  constructor() {
    // Note: @xenova/transformers loads ONNX weights, so export the
    // fine-tuned checkpoint to ONNX before pointing at a local path
    this.modelPath = "./models/medical-embedder-v1";
    this.model = null;
  }

  async embed(text) {
    // Load the fine-tuned checkpoint once, then reuse it
    if (!this.model) {
      this.model = await AutoModel.from_pretrained(this.modelPath);
    }
    // ... embedding logic (tokenize, mean-pool, normalize as in Pattern 1)
  }
}
```
Pattern 3: Hybrid Embedders
Combine multiple embedding sources for enhanced retrieval:
```javascript
class HybridEmbedder {
  constructor(options = {}) {
    this.semanticEmbedder = new SentenceTransformerEmbedder();
    this.lexicalEmbedder = new TFIDFEmbedder(); // any embedder with the same contract
    this.semanticWeight = options.semanticWeight || 0.7;
    this.lexicalWeight = options.lexicalWeight || 0.3;
  }

  async embed(text) {
    const [semantic, lexical] = await Promise.all([
      this.semanticEmbedder.embed(text),
      this.lexicalEmbedder.embed(text),
    ]);
    // Weighted combination
    return this.combine(semantic, lexical);
  }

  combine(semantic, lexical) {
    // Pad to the larger dimensionality, then take the weighted sum
    const maxDim = Math.max(semantic.length, lexical.length);
    const combined = new Array(maxDim).fill(0);
    for (let i = 0; i < semantic.length; i++) {
      combined[i] += semantic[i] * this.semanticWeight;
    }
    for (let i = 0; i < lexical.length; i++) {
      combined[i] += lexical[i] * this.lexicalWeight;
    }
    return combined;
  }
}
```
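The snippet above assumes a `TFIDFEmbedder` that exposes the same contract on the lexical side. A minimal sketch using the hashing trick (a production version would also weight terms by inverse document frequency):

```javascript
const crypto = require("crypto");

// Hash terms into a fixed-size vector so the output combines
// cleanly with dense embeddings
class TFIDFEmbedder {
  constructor(options = {}) {
    this.dimensions = options.dimensions || 384;
  }

  async embed(text) {
    const vector = new Array(this.dimensions).fill(0);
    const terms = text.toLowerCase().match(/[a-z0-9]+/g) || [];
    for (const term of terms) {
      // Stable term-to-bucket mapping via a content hash
      const digest = crypto.createHash("md5").update(term).digest();
      vector[digest.readUInt32BE(0) % this.dimensions] += 1;
    }
    // L2-normalize so the weighted combination stays balanced
    const norm = Math.sqrt(vector.reduce((s, v) => s + v * v, 0)) || 1;
    return vector.map((v) => v / norm);
  }
}
```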
Optimization Techniques
1. Caching
Cache embeddings for frequently accessed texts:
```javascript
class CachedEmbedder {
  constructor(baseEmbedder, options = {}) {
    this.baseEmbedder = baseEmbedder;
    this.cache = new Map();
    this.maxCacheSize = options.maxCacheSize || 10000;
  }

  async embed(text) {
    const key = this.hash(text);
    if (this.cache.has(key)) {
      // Refresh recency: re-insert so Map iteration order tracks LRU
      const cached = this.cache.get(key);
      this.cache.delete(key);
      this.cache.set(key, cached);
      return cached;
    }
    const embedding = await this.baseEmbedder.embed(text);
    // LRU eviction: the oldest entry is the first key in insertion order
    if (this.cache.size >= this.maxCacheSize) {
      const oldestKey = this.cache.keys().next().value;
      this.cache.delete(oldestKey);
    }
    this.cache.set(key, embedding);
    return embedding;
  }

  hash(text) {
    // Content hash as a compact, collision-resistant cache key
    return require("crypto").createHash("sha256").update(text).digest("hex");
  }
}
```
2. Batch Processing
Maximize GPU utilization with batching:
```javascript
class BatchOptimizedEmbedder {
  constructor(baseEmbedder, options = {}) {
    this.baseEmbedder = baseEmbedder;
    this.batchSize = options.batchSize || 32;
    this.queue = [];
    this.processing = false;
  }

  async embed(text) {
    return new Promise((resolve, reject) => {
      this.queue.push({ text, resolve, reject });
      if (!this.processing) {
        this.processing = true;
        // Defer to a microtask so synchronous bursts of embed() calls
        // coalesce into full batches instead of a batch of one
        queueMicrotask(() => this.processBatch());
      }
    });
  }

  async processBatch() {
    while (this.queue.length > 0) {
      const batch = this.queue.splice(0, this.batchSize);
      const texts = batch.map((item) => item.text);
      try {
        const embeddings = await this.baseEmbedder.embedBatch(texts);
        batch.forEach((item, index) => item.resolve(embeddings[index]));
      } catch (error) {
        batch.forEach((item) => item.reject(error));
      }
    }
    this.processing = false;
  }
}
```
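Because calls queue up before processing, concurrent requests are served in shared batches. A sketch (`baseEmbedder` and `documents` stand in for your own instances):

```javascript
const batched = new BatchOptimizedEmbedder(baseEmbedder, { batchSize: 32 });

// 100 concurrent embeds resolve in roughly 4 batched forward passes
const vectors = await Promise.all(
  documents.map((doc) => batched.embed(doc.text)),
);
```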
3. Quantization
Reduce memory footprint with int8 quantization:
```javascript
class QuantizedEmbedder {
  constructor(baseEmbedder) {
    this.baseEmbedder = baseEmbedder;
  }

  async embed(text) {
    const embedding = await this.baseEmbedder.embed(text);
    // Keep min/max alongside the stored vector: they are required
    // to dequantize later
    const { values } = this.quantize(embedding);
    return values;
  }

  quantize(embedding) {
    // Quantize to int8 (-128 to 127)
    const min = Math.min(...embedding);
    const max = Math.max(...embedding);
    const scale = 255 / (max - min || 1); // guard against constant vectors
    const values = embedding.map((val) => {
      const quantized = Math.round((val - min) * scale) - 128;
      return Math.max(-128, Math.min(127, quantized));
    });
    return { values, min, max };
  }

  dequantize(quantized, min, max) {
    const scale = (max - min) / 255;
    return quantized.map((val) => (val + 128) * scale + min);
  }
}
```
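A quick round-trip sketch shows the trade-off (the input values are illustrative; `baseEmbedder` is assumed from the earlier patterns):

```javascript
const q = new QuantizedEmbedder(baseEmbedder);

const original = [0.12, -0.5, 0.33, 0.9];
const { values, min, max } = q.quantize(original);
const restored = q.dequantize(values, min, max);

// Each float32 becomes one int8: roughly 4x smaller at rest
// (store values as an Int8Array plus two floats for min/max)
console.log(values, restored);
```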
Integration with RAG Pipeline Utils
```javascript
const { createRagPipeline } = require("@devilsdev/rag-pipeline-utils");
// Paths assume the classes above live under ./embedders/
const SentenceTransformerEmbedder = require("./embedders/sentence-transformer-embedder");
const CachedEmbedder = require("./embedders/cached-embedder");
const BatchOptimizedEmbedder = require("./embedders/batch-optimized-embedder");

// Create embedder with optimizations
const embedder = new CachedEmbedder(
  new BatchOptimizedEmbedder(
    new SentenceTransformerEmbedder({
      modelName: "sentence-transformers/all-mpnet-base-v2",
    }),
  ),
);

// Use in pipeline
const pipeline = createRagPipeline({
  embedder,
  retriever: myRetriever,
  llm: myLLM,
});

// Embed documents
await pipeline.ingest("./documents");

// The query path uses the same embedder
const result = await pipeline.query("How does authentication work?");
```
Performance Benchmarks
Indicative performance when embedding 10,000 documents (exact numbers depend on hardware):
| Embedder | Throughput | Latency (p95) | Memory |
|---|---|---|---|
| OpenAI API | 50 texts/sec | 200ms | Minimal |
| Sentence Transformers (CPU) | 100 texts/sec | 100ms | 500MB |
| Sentence Transformers (GPU) | 500 texts/sec | 20ms | 2GB |
| Cached + Batched | 1000 texts/sec | 10ms | 1GB |
Best Practices
- Choose the Right Base Model: Start with proven models like all-mpnet-base-v2 or all-MiniLM-L6-v2
- Fine-tune When Necessary: Only fine-tune when general models underperform
- Implement Caching: Cache embeddings for static documents
- Use Batch Processing: Always implement batch methods for large-scale operations
- Monitor Quality: Track retrieval metrics to ensure embedding quality (see the sketch after this list)
- Version Models: Use model versioning to enable A/B testing
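For the monitoring point above, even a simple cosine spot-check in CI can catch regressions after retraining (a sketch; the example texts are illustrative):

```javascript
// Related texts should score higher than unrelated ones
function cosineSimilarity(a, b) {
  let dot = 0,
    normA = 0,
    normB = 0;
  for (let i = 0; i < a.length; i++) {
    dot += a[i] * b[i];
    normA += a[i] * a[i];
    normB += b[i] * b[i];
  }
  return dot / (Math.sqrt(normA) * Math.sqrt(normB));
}

const related = cosineSimilarity(
  await embedder.embed("patient symptoms"),
  await embedder.embed("medical diagnosis"),
);
const unrelated = cosineSimilarity(
  await embedder.embed("patient symptoms"),
  await embedder.embed("parking garage rates"),
);
console.assert(related > unrelated, "embedding quality regression");
```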
Conclusion
Custom embedders provide flexibility and control for production RAG systems. By understanding the embedder contract, implementation patterns, and optimization techniques, you can build embedders tailored to your domain while maintaining high performance and cost efficiency.
